コード例 #1
0
ファイル: main.py プロジェクト: johndpope/dcgan-tfslim
def main(_):
    """Build a DCGAN inside a TensorFlow session, restore it and optionally train.

    Steps, in order: pretty-print the parsed flags, construct the model,
    make sure the checkpoint and per-model log directories exist, load a
    checkpoint when one is reported, then train if ``FLAGS.train`` is set.

    Args:
        _: Unused positional argument.

    Raises:
        IOError: If ``dcgan.checkpoint_exists()`` is true but ``dcgan.load()``
            returns a falsy value.
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # path checks: create every output directory that is still missing
        required_dirs = (
            FLAGS.checkpoint_dir,
            os.path.join(FLAGS.log_dir, dcgan.get_model_dir()),
        )
        for path in required_dirs:
            if not os.path.exists(path):
                os.makedirs(path)

        # load checkpoint if found
        # Single-argument print(...) is used throughout so the output is the
        # same on Python 2 and the function also parses on Python 3.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if dcgan.load():
                print("success!")
            else:
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
        else:
            print("No checkpoints found. Training from scratch.")
            # Return value deliberately ignored here; presumably load() also
            # prepares a fresh model when nothing can be restored -- TODO
            # confirm against DCGAN.load.
            dcgan.load()

        # train DCGAN
        if FLAGS.train:
            train(dcgan)
        else:
            # NOTE(review): load() has already run on every path that reaches
            # this point; this second call looks like a placeholder for
            # inference code -- confirm before removing.
            dcgan.load()
コード例 #2
0
ファイル: main.py プロジェクト: SmartAI/GAN
def main(_):
    """Build a DCGAN, restore or train it, then generate samples.

    Steps, in order: pretty-print the parsed flags, construct the model in a
    TensorFlow session, make sure the checkpoint, sample and log directories
    exist, load a checkpoint when one is reported, train if ``FLAGS.train``
    is set, and finally run the sampling and z-visualization routines.

    Args:
        _: Unused positional argument.

    Raises:
        IOError: If a checkpoint is reported but ``dcgan.load()`` returns a
            falsy value, or if no checkpoint exists and ``FLAGS.train`` is
            not set (there would be nothing to sample from).
    """
    pp.pprint(FLAGS.__flags)
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # Sample and log output live in per-model subdirectories; compute the
        # model directory name once instead of once per path expression.
        model_dir = dcgan.get_model_dir()
        required_dirs = (
            FLAGS.checkpoint_dir,
            os.path.join(FLAGS.sample_dir, model_dir),
            os.path.join(FLAGS.log_dir, model_dir),
        )
        for path in required_dirs:
            if not os.path.exists(path):
                os.makedirs(path)

        # Single-argument print(...) is used throughout so the output is the
        # same on Python 2 and the function also parses on Python 3.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints")
            if dcgan.load():
                print("Success")
            else:
                raise IOError("Could not read checkpoints from {}".format(
                    FLAGS.checkpoint_dir))
        else:
            # Without a checkpoint the only valid mode is training.
            if not FLAGS.train:
                raise IOError("No checkpoints found")
            print("No checkpoints found. Training from scratch")
            # Return value deliberately ignored here; presumably load() also
            # prepares a fresh model when nothing can be restored -- TODO
            # confirm against DCGAN.load.
            dcgan.load()

        if FLAGS.train:
            train(dcgan)

        # Sampling runs in both modes: after training, or straight from the
        # restored checkpoint.
        print("Generating samples...")
        inference.sample_images(dcgan)
        inference.visualize_z(dcgan)
コード例 #3
0
ファイル: main.py プロジェクト: mqtlam/dcgan-tfslim
def main(_):
    """Build a DCGAN, restore or train it, then sample and visualize.

    Steps, in order: pretty-print the parsed flags, construct the model in a
    TensorFlow session, make sure the checkpoint, log and sample directories
    exist, load a checkpoint when one is reported, train if ``FLAGS.train``
    is set, and finally run the sampling and z-visualization routines.

    Args:
        _: Unused positional argument.

    Raises:
        IOError: If a checkpoint is reported but ``dcgan.load()`` returns a
            falsy value, or if no checkpoint exists and ``FLAGS.train`` is
            not set (sampling needs a trained model).
    """
    pp.pprint(FLAGS.__flags)

    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)

        # path checks: log and sample output live in per-model
        # subdirectories, so compute the model directory name once.
        model_dir = dcgan.get_model_dir()
        required_dirs = (
            FLAGS.checkpoint_dir,
            os.path.join(FLAGS.log_dir, model_dir),
            os.path.join(FLAGS.sample_dir, model_dir),
        )
        for path in required_dirs:
            if not os.path.exists(path):
                os.makedirs(path)

        # load checkpoint if found
        # Single-argument print(...) is used throughout so the output is the
        # same on Python 2 and the function also parses on Python 3.
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if dcgan.load():
                print("success!")
            else:
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
        else:
            # Without a checkpoint the only valid mode is training.
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            # Return value deliberately ignored here; presumably load() also
            # prepares a fresh model when nothing can be restored -- TODO
            # confirm against DCGAN.load.
            dcgan.load()

        # train DCGAN
        if FLAGS.train:
            train(dcgan)

        # inference/visualization: runs after training or straight from the
        # restored checkpoint
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)