示例#1
0
def main(_):
    """Build a DCGAN from the parsed flags, then train or evaluate it.

    Pretty-prints the parsed flags, creates the checkpoint and sample
    directories when they are missing, constructs the model inside a
    ``tf.Session`` and either trains it or loads a checkpoint and prints the
    result of ``dcgan.evaluate(6400)``. When ``FLAGS.visualize`` is set,
    ``visualize`` is called afterwards with option ``2``.

    Args:
        _: Unused. Presumably the argv list handed over by the application
            runner -- confirm against the caller.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Output locations must exist before the model writes to them.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        is_mnist = FLAGS.dataset == 'mnist'

        # Constructor arguments shared by both dataset configurations.
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        if is_mnist:
            # MNIST uses fixed values here instead of the corresponding flags.
            model_kwargs.update(y_dim=10, output_size=28, c_dim=1)
        else:
            model_kwargs.update(output_size=FLAGS.output_size,
                                c_dim=FLAGS.c_dim)
        dcgan = DCGAN(sess, **model_kwargs)

        if not FLAGS.is_train:
            dcgan.load(FLAGS.checkpoint_dir)
            print(dcgan.evaluate(6400))
        else:
            dcgan.train(FLAGS)

        if FLAGS.visualize:
            # Disabled export of layer parameters, kept for reference:
            # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
            #                               [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
            #                               [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
            #                               [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
            #                               [dcgan.h4_w, dcgan.h4_b, None])

            # The code below performs the visualization.
            vis_option = 2
            visualize(sess, dcgan, FLAGS, vis_option)
示例#2
0
def main(_):
    """Evaluate a DCGAN on the flowers dataset and optionally visualize it.

    Creates the checkpoint and result directories when they are missing,
    builds a ``FlowersData`` dataset for ``FLAGS.subset``, verifies that it
    reports data files, constructs a ``DCGAN`` with ``z_dim=200`` and calls
    its ``evaluate`` method. When ``FLAGS.visualize`` is set, ``visualize``
    is called afterwards with option ``2``.

    Args:
        _: Unused. Presumably the argv list handed over by the application
            runner -- confirm against the caller.

    Raises:
        ValueError: If ``dataset.data_files()`` returns a falsy value.
    """
    for directory in (FLAGS.checkpoint_dir, FLAGS.result_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    dataset = FlowersData(subset=FLAGS.subset)
    # Explicit check instead of `assert`: assert statements are removed when
    # Python runs with -O, which would silently skip this validation.
    if not dataset.data_files():
        raise ValueError(
            'No data files found for subset %r' % (FLAGS.subset,))

    dcgan = DCGAN(z_dim=200, dataset=dataset)
    dcgan.evaluate()

    if FLAGS.visualize:
        # The code below performs the visualization.
        vis_option = 2
        visualize(dcgan, FLAGS, vis_option)
示例#3
0
def main(_):
    """Resolve flag defaults and output paths, then train or evaluate a DCGAN.

    Unset size flags fall back to other size flags, a run name is derived
    from the dataset, output size and batch size, and the checkpoint and
    sample directories are nested under ``out_dir/<run name>`` and created
    when missing. The model is then built inside a ``tf.Session`` whose
    config has ``gpu_options.allow_growth`` set to ``True``.

    Args:
        _: Unused. Presumably the argv list handed over by the application
            runner -- confirm against the caller.
    """
    # Order matters: output_width falls back to output_height, which may
    # itself have just been filled in from input_height.
    if FLAGS.output_height is None:
        FLAGS.output_height = FLAGS.input_height
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    FLAGS.out_name = '{}-{}x{}-bn{}'.format(
        FLAGS.dataset, FLAGS.output_height, FLAGS.output_width,
        FLAGS.batch_size)
    # Alternative run name without the batch size:
    # FLAGS.out_name = '{}-{}x{}'.format(FLAGS.dataset, FLAGS.output_height,FLAGS.output_width)

    # Re-root the output locations under the per-run directory.
    FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
    FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
    FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model_kwargs = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            # The sample count is taken from the batch-size flag.
            sample_num=FLAGS.batch_size,
            z_dim=FLAGS.z_dim,
            dataset_name=FLAGS.dataset,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            data_dir=FLAGS.data_dir,
            out_dir=FLAGS.out_dir,
            max_to_keep=FLAGS.max_to_keep,
        )
        dcgan = DCGAN(sess, **model_kwargs)

        if not FLAGS.train:
            # Disabled alternative: dcgan.genPics(FLAGS)
            dcgan.evaluate(FLAGS)
        else:
            dcgan.train(FLAGS)