예제 #1
0
def main(unused_argv):
    """Fetch one batch of 64 dataset entries and write its images to PNG files.

    The dataset named by ``FLAGS.dataset`` is looked up in
    ``dataset.get_datasets()``, batched, and a single batch is evaluated in
    a fresh graph/session. Depending on ``FLAGS.grid`` the images are saved
    either one file per image or tiled into a single 8 x 8 grid image.

    Args:
        unused_argv: Ignored; not read by this function.
    """
    # NOTE(review): save_dir is created here, but the file names handed to
    # save_image below do not include it -- presumably save_image prepends
    # FLAGS.save_dir itself; confirm against its definition.
    save_dir_exists = tf.gfile.IsDirectory(FLAGS.save_dir)
    if not save_dir_exists:
        tf.gfile.MakeDirs(FLAGS.save_dir)

    with tf.Graph().as_default():
        dataset_factory = dataset.get_datasets()[FLAGS.dataset]
        # The two trailing positional arguments (4 and 128 * 1024) are passed
        # through untouched; their meaning is defined by the dataset factory.
        examples = dataset_factory(
            FLAGS.dataset, FLAGS.dataset_split, 4, 128 * 1024)
        next_batch = examples.batch(64).make_one_shot_iterator().get_next()

        with tf.Session() as session:
            # Only the first component of the fetched batch is kept.
            data = session.run(next_batch)[0]

    if FLAGS.grid:
        # Tile the first 64 images row-major into an 8 x 8 grid: join each
        # run of 8 images along axis 1, then stack the 8 rows along axis 0.
        rows = [
            np.concatenate(
                [data[row * 8 + col] for col in range(8)], axis=1)
            for row in range(8)
        ]
        save_image(np.concatenate(rows, axis=0), FLAGS.dataset + "_grid.png")
    else:
        # One numbered file per image in the batch.
        for index, image in enumerate(data):
            save_image(image, FLAGS.dataset + "_%d.png" % index)
예제 #2
0
def main(unused_argv):
  """Evaluate the task described in ``FLAGS.eval_task_workdir``.

  Registers the MultiGAN models, their hyper-parameters and the project
  datasets with the framework registries, parses the text-format task proto
  stored in the ``task`` file of the work directory, makes sure the output
  directory exists and then calls ``EvalTask``.

  Args:
    unused_argv: Ignored; not read by this function.
  """
  gan_lib.MODELS.update({
      "MultiGAN": multi_gan.MultiGAN,
      "MultiGANBackground": multi_gan_background.MultiGANBackground
  })
  params.PARAMETERS.update({
      "MultiGAN": multi_gan.MultiGANHyperParams,
      "MultiGANBackground": multi_gan_background.MultiGANBackgroundHyperParams
  })

  gan_lib.DATASETS.update(dataset.get_datasets())
  params.DATASET_PARAMS.update(dataset.get_dataset_params())

  task_workdir = FLAGS.eval_task_workdir

  # The task definition is stored as a text-format proto in <workdir>/task.
  task = simple_task_pb2.Task()
  with open(os.path.join(task_workdir, "task"), "r") as f:
    text_format.Parse(f.read(), task)

  options = task_utils.ParseOptions(task)

  # Results go to a per-model sub-directory of FLAGS.out_dir.
  out_dir = os.path.join(FLAGS.out_dir, GetModelDir(options))
  if not tf.gfile.IsDirectory(out_dir):
    tf.gfile.MakeDirs(out_dir)

  task_string = text_format.MessageToString(task)
  # print() does not apply printf-style formatting to its arguments, so the
  # task string has to be interpolated explicitly; passing it as a second
  # argument would emit a literal "%s" instead.
  print("\nWill evaluate task\n%s\n\n" % task_string)

  EvalTask(options, task_workdir, out_dir)
예제 #3
0
def AddGansAndDatasets():
    """Register the MultiGAN variants and the project datasets.

    Adds the two MultiGAN model classes and their hyper-parameter classes to
    the framework registries, declares them as supported by the evaluation
    library together with default values for their options, and merges in
    the datasets and dataset parameters provided by ``dataset``.

    Must be called right after entering main.
    """
    model_classes = {
        "MultiGAN": multi_gan.MultiGAN,
        "MultiGANBackground": multi_gan_background.MultiGANBackground,
    }
    hyper_param_classes = {
        "MultiGAN": multi_gan.MultiGANHyperParams,
        "MultiGANBackground":
            multi_gan_background.MultiGANBackgroundHyperParams,
    }
    # Default values registered with the evaluation library for the
    # MultiGAN-specific options.
    option_defaults = {
        "k": -1,
        "aggregate": "none",
        "embedding_dim": -1,
        "n_blocks": -1,
        "share_block_weights": False,
        "n_heads": -1,
        "background_interaction": False,
    }

    # Models and hyper-parameters first, then the evaluation-side metadata.
    gan_lib.MODELS.update(model_classes)
    params.PARAMETERS.update(hyper_param_classes)
    eval_gan_lib.SUPPORTED_GANS.extend(["MultiGAN", "MultiGANBackground"])
    eval_gan_lib.DEFAULT_VALUES.update(option_defaults)

    # Finally expose the datasets and their parameters.
    available_datasets = dataset.get_datasets()
    gan_lib.DATASETS.update(available_datasets)
    dataset_params = dataset.get_dataset_params()
    params.DATASET_PARAMS.update(dataset_params)