コード例 #1
0
def main(unused_argv):
    """Generate images with a trained generator, then score them with a discriminator.

    Validates the command-line flags, creates ``FLAGS.output_dir`` and its
    ``images/`` sub-directory, runs ``generate`` on one graph and passes the
    returned images to ``discriminate`` on a second, separate graph.

    Args:
      unused_argv: Positional command-line arguments; ignored.

    Raises:
      ValueError: If ``output_dir`` or ``input_data_pattern`` is empty, or if
        ``num_generate`` is not within ``[0, num_total_images]``.
    """
    logging.set_verbosity(tf.logging.INFO)

    if FLAGS.use_mnist:
        reader = readers.MnistReader()
    else:
        reader = readers.FaceReader()

    # Validate every flag before touching the filesystem so that a bad
    # invocation does not leave empty output directories behind.
    # `not x` instead of `x is ""`: identity comparison with a string literal
    # is implementation-dependent (and a SyntaxWarning on Python 3.8+).
    if not FLAGS.output_dir:
        raise ValueError("'output_dir' was not specified. "
                         "Unable to continue with generate_and_discriminate.")

    if not FLAGS.input_data_pattern:
        raise ValueError("'input_data_pattern' was not specified. "
                         "Unable to continue with generate_and_discriminate.")

    if not (0 <= FLAGS.num_generate <= FLAGS.num_total_images):
        raise ValueError("'num_generate' should be between "
                         "[0, num_total_images]. Unable to continue "
                         "with generate_and_discriminate.")

    # NOTE(review): plain concatenation (not os.path.join) means output_dir is
    # expected to end with a path separator; kept as-is in case generate /
    # discriminate build their paths the same way — confirm before changing.
    images_dir = FLAGS.output_dir + 'images/'
    for directory in (FLAGS.output_dir, images_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Make separate graphs in order to load two different models.
    G_graph = tf.Graph()
    D_graph = tf.Graph()

    images = generate(reader, G_graph, FLAGS.G_train_dir,
                      FLAGS.input_data_pattern, FLAGS.output_dir,
                      FLAGS.num_generate, FLAGS.num_total_images)

    discriminate(images, D_graph, FLAGS.D_train_dir, FLAGS.output_dir)
コード例 #2
0
def main(unused_argv):
    """Run ``linear_interpolation`` with the model in ``FLAGS.train_dir``.

    The dataset reader is chosen by ``FLAGS.use_mnist`` (MNIST when set,
    faces otherwise) and the result target is ``FLAGS.output_file``.

    Args:
      unused_argv: Positional command-line arguments; ignored.

    Raises:
      ValueError: If ``output_file`` is empty.
    """
    logging.set_verbosity(tf.logging.INFO)

    if FLAGS.use_mnist:
        reader = readers.MnistReader()
    else:
        reader = readers.FaceReader()

    # `not x` instead of `x is ""`: identity comparison with a string literal
    # is implementation-dependent (and a SyntaxWarning on Python 3.8+).
    if not FLAGS.output_file:
        raise ValueError("'output_file' was not specified. "
                         "Unable to continue with linear_interpolation.")

    linear_interpolation(reader, FLAGS.train_dir, FLAGS.output_file)
コード例 #3
0
ファイル: inference.py プロジェクト: xhae/tutorial_mnist
def main(unused_argv):
    """Run ``inference`` on MNIST data with the model in ``FLAGS.train_dir``.

    Args:
      unused_argv: Positional command-line arguments; ignored.

    Raises:
      ValueError: If ``output_file`` or ``input_data_pattern`` is empty.
    """
    logging.set_verbosity(tf.logging.INFO)

    reader = readers.MnistReader()

    # `not x` instead of `x is ""`: identity comparison with a string literal
    # is implementation-dependent (and a SyntaxWarning on Python 3.8+).
    if not FLAGS.output_file:
        raise ValueError("'output_file' was not specified. "
                         "Unable to continue with inference.")

    if not FLAGS.input_data_pattern:
        raise ValueError("'input_data_pattern' was not specified. "
                         "Unable to continue with inference.")

    # NOTE(review): 8192 is passed positionally as the fifth argument of
    # `inference`; its meaning (presumably a batch size) is not visible here —
    # confirm against the `inference` signature before changing it.
    inference(reader, FLAGS.train_dir, FLAGS.input_data_pattern,
              FLAGS.output_file, 8192)
コード例 #4
0
def evaluate():
    """Build the GAN evaluation graph and run ``evaluation_loop`` on it.

    The graph tensors are fetched from the collections populated by
    ``build_graph``. The last evaluated global step (starting at -1) is passed
    into each ``evaluation_loop`` call and replaced by its return value.
    Summaries are written to ``FLAGS.train_dir``. The loop stops after one
    iteration when ``FLAGS.run_once`` is set; otherwise it never returns.

    Raises:
      IOError: If ``FLAGS.eval_data_pattern`` is empty.
    """
    with tf.Graph().as_default():
        # The graph-level seed is attached to the *current default* graph, so
        # it must be set inside this `with` block; set before it, the seed
        # lands on the global default graph and not on the graph built here.
        tf.set_random_seed(0)  # for reproducibility

        # Fail fast on a missing flag, before any model is constructed.
        # `not x` instead of `x is ""`: identity comparison with a string
        # literal is implementation-dependent (SyntaxWarning on 3.8+).
        if not FLAGS.eval_data_pattern:
            raise IOError("'eval_data_pattern' was not specified. "
                          "Nothing to evaluate.")

        if FLAGS.use_mnist:
            reader = readers.MnistReader()
        else:
            reader = readers.FaceReader()

        generator_model = find_class_by_name(FLAGS.generator_model, [models])()
        discriminator_model = find_class_by_name(FLAGS.discriminator_model,
                                                 [models])()
        label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

        build_graph(reader=reader,
                    generator_model=generator_model,
                    discriminator_model=discriminator_model,
                    eval_data_pattern=FLAGS.eval_data_pattern,
                    label_loss_fn=label_loss_fn,
                    num_readers=FLAGS.num_readers,
                    batch_size=FLAGS.batch_size)
        logging.info("built evaluation graph")

        # Tensors registered by build_graph under these collection names.
        p_fake_batch = tf.get_collection("p_for_fake")[0]
        p_real_batch = tf.get_collection("p_for_data")[0]
        G_loss = tf.get_collection("G_loss")[0]
        D_loss = tf.get_collection("D_loss")[0]
        noise_input = tf.get_collection("noise_input_placeholder")[0]
        summary_op = tf.get_collection("summary_op")[0]

        saver = tf.train.Saver(tf.global_variables())
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=tf.get_default_graph())

        evl_metrics = eval_util.EvaluationMetrics()

        last_global_step_val = -1
        while True:
            last_global_step_val = evaluation_loop(p_fake_batch, p_real_batch,
                                                   G_loss, D_loss, noise_input,
                                                   summary_op, saver,
                                                   summary_writer, evl_metrics,
                                                   last_global_step_val,
                                                   FLAGS.batch_size)
            if FLAGS.run_once:
                break
コード例 #5
0
ファイル: eval.py プロジェクト: xhae/tutorial_mnist
def evaluate():
    """Build the MNIST evaluation graph and run ``evaluation_loop`` on it.

    The graph tensors are fetched from the collections populated by
    ``build_graph``. The last evaluated global step (starting at -1) is passed
    into each ``evaluation_loop`` call and replaced by its return value.
    Summaries are written to ``FLAGS.train_dir``. The loop stops after one
    iteration when ``FLAGS.run_once`` is set; otherwise it never returns.

    Raises:
      IOError: If ``FLAGS.eval_data_pattern`` is empty.
    """
    with tf.Graph().as_default():
        # The graph-level seed is attached to the *current default* graph, so
        # it must be set inside this `with` block; set before it, the seed
        # lands on the global default graph and not on the graph built here.
        tf.set_random_seed(0)  # for reproducibility

        # Fail fast on a missing flag, before any model is constructed.
        # `not x` instead of `x is ""`: identity comparison with a string
        # literal is implementation-dependent (SyntaxWarning on 3.8+).
        if not FLAGS.eval_data_pattern:
            raise IOError("'eval_data_pattern' was not specified. "
                          "Nothing to evaluate.")

        reader = readers.MnistReader()

        model = find_class_by_name(FLAGS.model, [mnist_models])()
        label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()

        build_graph(reader=reader,
                    model=model,
                    eval_data_pattern=FLAGS.eval_data_pattern,
                    label_loss_fn=label_loss_fn,
                    num_readers=FLAGS.num_readers,
                    batch_size=FLAGS.batch_size)
        logging.info("built evaluation graph")

        # Tensors registered by build_graph under these collection names.
        prediction_batch = tf.get_collection("predictions")[0]
        label_batch = tf.get_collection("labels")[0]
        loss = tf.get_collection("loss")[0]
        summary_op = tf.get_collection("summary_op")[0]

        saver = tf.train.Saver(tf.global_variables())
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=tf.get_default_graph())

        # NOTE(review): the meaning of the second argument (2) is not visible
        # here — presumably a top-k setting; confirm in eval_util.
        evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, 2)

        last_global_step_val = -1
        while True:
            last_global_step_val = evaluation_loop(prediction_batch,
                                                   label_batch, loss,
                                                   summary_op, saver,
                                                   summary_writer, evl_metrics,
                                                   last_global_step_val)
            if FLAGS.run_once:
                break
コード例 #6
0
ファイル: train.py プロジェクト: xhae/tutorial_mnist
def get_reader():
    """Return a freshly constructed ``readers.MnistReader``.

    Returns:
      A new ``readers.MnistReader`` instance.
    """
    return readers.MnistReader()
コード例 #7
0
def get_reader():
    """Return the dataset reader selected by ``FLAGS.use_mnist``.

    Returns:
      A new ``readers.MnistReader`` when ``FLAGS.use_mnist`` is truthy,
      otherwise a new ``readers.FaceReader``.
    """
    reader_cls = readers.MnistReader if FLAGS.use_mnist else readers.FaceReader
    return reader_cls()