Example #1
def run_with_options(options, task_workdir, progress_reporter=None):
    """Runs the task with arbitrary options."""
    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    tfprofile_dir = os.path.join(task_workdir, "tfprofile")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    dataset = options["dataset"]

    logging.info(
        "Running tasks with gan_type: %s dataset: %s with parameters %s",
        gan_type, dataset, str(options))
    logging.info("Checkpoint dir: %s result_dir: %s gan_log_dir: %s",
                 checkpoint_dir, result_dir, gan_log_dir)

    ops.check_folder(checkpoint_dir)
    ops.check_folder(tfprofile_dir)
    ops.check_folder(result_dir)
    ops.check_folder(gan_log_dir)

    # Set the dataset shuffling seed if specified in options.
    dataset_seed = DEFAULT_DATASET_SEED
    if "dataset_seed" in options:
        logging.info("Seeting dataset seed to %d", options["dataset_seed"])
        dataset_seed = options["dataset_seed"]

    dataset_content = load_dataset(dataset, split_name="train")
    dataset_content = dataset_content.repeat().shuffle(10000,
                                                       seed=dataset_seed)

    with tf.Graph().as_default():
        if "tf_seed" in options:
            seed = options["tf_seed"]
            logging.info("Setting tf and np random seed to %d", seed)
            tf.set_random_seed(seed)
            np.random.seed(seed)

        with profile_context(tfprofile_dir):
            with tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True)) as sess:
                gan = create_gan(gan_type=gan_type,
                                 dataset=dataset,
                                 dataset_content=dataset_content,
                                 options=options,
                                 checkpoint_dir=checkpoint_dir,
                                 result_dir=result_dir,
                                 gan_log_dir=gan_log_dir)
                gan.build_model()
                print " [*] Training started!"
                gan.train(sess, progress_reporter)
                print " [*] Training finished!"
Example #2
def _save_samples(self, step, sess, filename_suffix, z_distribution=None):
  """Generates a grid of samples and saves it under filename_suffix."""
  if z_distribution is None:
    z_distribution = self.z_generator
  z_sample = z_distribution(self.batch_size, self.z_dim)
  grid_shape = self._image_grid_shape()
  samples = sess.run(self.fake_images_merged,
                     feed_dict={self.z: z_sample})
  samples = samples.reshape((grid_shape[0] * self.input_height,
                             grid_shape[1] * self.input_width, -1)).squeeze()
  out_folder = ops.check_folder(os.path.join(self.result_dir, self.model_dir))
  full_path = os.path.join(out_folder, filename_suffix)
  ops.save_images(samples, full_path)
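
The z_distribution argument is any callable with the same (batch_size, z_dim) signature as self.z_generator. A minimal stand-in could look like this; the uniform [-1, 1] sampling range is an assumption, not taken from the snippet:

import numpy as np

def uniform_z(batch_size, z_dim):
  # Same call signature as self.z_generator above; returns one latent
  # batch per call.
  return np.random.uniform(-1.0, 1.0, size=(batch_size, z_dim)).astype(np.float32)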
Example #3
def maybe_save_samples(self, idx):
  """Saves training results every 5000 steps."""
  if np.mod(idx, 5000) != 0:
    return
  z_sample = self.z_generator(self.batch_size, self.z_dim)
  samples = self.sess.run(self.fake_images_merged,
                          feed_dict={self.z: z_sample})
  samples = samples.reshape(
      (8 * self.input_height, 8 * self.input_width, -1)).squeeze()
  out_folder = ops.check_folder(os.path.join(self.result_dir, self.model_dir))
  suffix = "%s_train_%04d.png" % (self.model_name, idx)
  full_path = os.path.join(out_folder, suffix)
  ops.save_images(samples, full_path)
Example #4
def visualize_results(self, step, z_distribution=None):
  """Generates and stores a set of fake images."""
  if z_distribution is None:
    z_distribution = self.z_generator
  z_sample = z_distribution(self.batch_size, self.z_dim)
  samples = self.sess.run(self.fake_images_merged,
                          feed_dict={self.z: z_sample})
  samples = samples.reshape(
      (8 * self.input_height, 8 * self.input_width, -1)).squeeze()
  out_folder = ops.check_folder(os.path.join(self.result_dir, self.model_dir))
  suffix = "%s_step%03d_test_all_classes.png" % (self.model_name, step)
  full_path = os.path.join(out_folder, suffix)
  ops.save_images(samples, full_path)
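
Because visualize_results accepts any callable with the (batch_size, z_dim) signature, a fixed latent batch can be reused across steps so the saved grids stay comparable. A hedged sketch, where gan stands for an already-built model instance and the sizes are hypothetical (the 8x8 reshape above implies batch_size == 64):

# Hypothetical sizes: a 64-sample batch of 100-dim latents.
fixed_z = np.random.uniform(-1.0, 1.0, size=(64, 100)).astype(np.float32)

def fixed_z_distribution(batch_size, z_dim):
  # Ignores its arguments and always returns the same latent batch, so
  # grids saved at different steps are directly comparable.
  return fixed_z

gan.visualize_results(step=10000, z_distribution=fixed_z_distribution)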
Example #5
def run_with_options(options,
                     task_workdir,
                     progress_reporter=None,
                     warm_start_from=None):
    """Runs the task with arbitrary options.

  Args:
    options: Dictionary with meta and hyper parameters.
    task_workdir: Directory to save logs, checkpoints, samples etc. If the
        subdirectory "checkpoint" contains checkpoints the method will attempt
        to load the latest checkpoint.
    progress_reporter: Callback function to report progress (parameters:
        step, steps_per_sec, progress, eta_minutes).
    warm_start_from: `tf.estimator.WarmStartSettings`. Only supported for
        estimator training.

  Raises:
    ValueError: For infeasible combinations of options.
  """
    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    tfprofile_dir = os.path.join(task_workdir, "tfprofile")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    dataset = options["dataset"]

    logging.info(
        "Running tasks with gan_type: %s dataset: %s with parameters %s",
        gan_type, dataset, str(options))
    logging.info("Checkpoint dir: %s result_dir: %s gan_log_dir: %s",
                 checkpoint_dir, result_dir, gan_log_dir)

    ops.check_folder(checkpoint_dir)
    ops.check_folder(tfprofile_dir)
    ops.check_folder(result_dir)
    ops.check_folder(gan_log_dir)

    if "tf_seed" in options:
        logging.info("Setting np random seed to %s", options["tf_seed"])
        np.random.seed(options["tf_seed"])

    # Set the dataset shuffling seed if specified in options.
    dataset_seed = DEFAULT_DATASET_SEED
    if "dataset_seed" in options:
        logging.info("Seeting dataset seed to %d", options["dataset_seed"])
        dataset_seed = options["dataset_seed"]

    dataset_content = load_dataset(dataset, split_name="train")
    dataset_content = dataset_content.repeat().shuffle(10000,
                                                       seed=dataset_seed)

    if options.get("use_estimator", options.get("use_tpu", False)):
        # Estimator mode supports CPU, GPU and TPU training.
        gan = create_gan(gan_type=gan_type,
                         dataset=dataset,
                         dataset_content=dataset_content,
                         options=options,
                         gan_log_dir=gan_log_dir,
                         result_dir=result_dir,
                         checkpoint_dir=checkpoint_dir)
        config = tf.contrib.tpu.RunConfig(
            model_dir=checkpoint_dir,
            tf_random_seed=options.get("tf_seed", None),
            save_checkpoints_steps=int(options["save_checkpoint_steps"]),
            keep_checkpoint_max=gan.max_checkpoints_to_keep,
            master=FLAGS.master,
            evaluation_master=FLAGS.master,
            tpu_config=tf.contrib.tpu.TPUConfig(
                iterations_per_loop=FLAGS.iterations_per_loop))
        print(" [*] Training started!")
        gan.train_with_estimator(config=config,
                                 warm_start_from=warm_start_from)
        print(" [*] Training finished!")
    else:
        if options.get("use_tpu", False):
            raise ValueError(
                "TPU experiments must run with use_estimator=True.")
        if warm_start_from:
            raise ValueError("Warm starting is only supported for estimator.")
        with tf.Graph().as_default():
            if "tf_seed" in options:
                logging.info("Setting tf random seed to %s",
                             options["tf_seed"])
                tf.set_random_seed(options["tf_seed"])
                # NumPy random seed is already set above.
            with profile_context(tfprofile_dir):
                config = tf.ConfigProto(allow_soft_placement=True)
                with tf.Session(config=config) as sess:
                    gan = create_gan(gan_type=gan_type,
                                     dataset=dataset,
                                     dataset_content=dataset_content,
                                     options=options,
                                     checkpoint_dir=checkpoint_dir,
                                     result_dir=result_dir,
                                     gan_log_dir=gan_log_dir)
                    gan.build_model()
                    print(" [*] Training started!")
                    gan.train(sess, progress_reporter)
                    print(" [*] Training finished!")