コード例 #1
0
ファイル: main.py プロジェクト: yanshui177/Iterative-GAN
def main(_):
    """Build a DCGAN from FLAGS, train or restore it, then visualize.

    The positional argument is ignored.

    Raises:
        Exception: when not training and ``dcgan.load`` reports failure.
    """
    pp.pprint(flags.FLAGS.__flags)

    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Append a timestamped sub-directory (local time) to the output path.
    timestamp = time.strftime('%Y-%m-%d_%H_%M_%S')
    FLAGS.OutputDirName += "/" + timestamp
    if not os.path.exists(FLAGS.OutputDirName):
        os.makedirs(FLAGS.OutputDirName)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        dcgan = DCGAN(sess, FLAGS)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")

        # Visualization runs in both train and test mode, using option 4.
        visualize_option = 4
        visualize(sess, dcgan, FLAGS, visualize_option)
コード例 #2
0
ファイル: main.py プロジェクト: ulzee/megapix-gan
def main(_):
    """Build a single-channel DCGAN from FLAGS and train or restore it.

    The positional argument is ignored.

    Raises:
        Exception: when not training and ``dcgan.load`` reports failure.
    """
    pp.pprint(flags.FLAGS.__flags)

    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # The sample directory is created underneath 'results/'.
    results_dir = 'results/%s' % FLAGS.sample_dir
    for directory in (FLAGS.checkpoint_dir, results_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model_options = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.sample_num,
            z_dim=FLAGS.z_dim,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            c_dim=1,
            grow=FLAGS.grow,
        )
        dcgan = DCGAN(sess, **model_options)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")
コード例 #3
0
def main(_):
    """Build a DCGAN with a ``y_dim``, train or restore it, then visualize.

    ``y_dim`` is 10 when ``FLAGS.dataset`` is ``'mnist'`` and 6 for any other
    dataset; every other constructor argument is read from FLAGS.

    Raises:
        Exception: when not training and ``dcgan.load`` reports failure.
    """
    pp.pprint(flags.FLAGS.__flags)

    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        # The two dataset cases differ only in y_dim.
        label_dim = 10 if FLAGS.dataset == 'mnist' else 6
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            y_dim=label_dim,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")

        # Visualization runs in both train and test mode, using option 1.
        visualize_option = 1
        visualize(sess, dcgan, FLAGS, visualize_option)
コード例 #4
0
def main(_):
    """Build a DCGAN from FLAGS, train or restore it, then visualize.

    The positional argument is ignored.

    Raises:
        Exception: when not training and ``dcgan.load`` reports failure.
    """
    pp.pprint(flags.FLAGS.__flags)

    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model_options = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            # NOTE(review): z_dim is fed from the generate_test_images flag.
            z_dim=FLAGS.generate_test_images,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            data_dir=FLAGS.data_dir,
        )
        dcgan = DCGAN(sess, **model_options)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")

        # Visualization runs in both train and test mode.
        visualize(sess, dcgan, FLAGS)
コード例 #5
0
def main(argv):
    """Restore a saved DCGAN and display a montage of generated samples.

    Args:
        argv: unused; present only to match the entry-point signature.
    """
    # turn off log message
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.FATAL)

    with tf.Session(config=utils.config(index=FLAGS.gpu_index)) as sess:

        # set network
        model_kwargs = {
            'sess': sess,
            'noise_dim': FLAGS.noise_dim,
            'image_size': FLAGS.image_size,
            'generator_layer': generator_layer,
            'discriminator_layer': discriminator_layer,
            'is_training': False
        }
        model = DCGAN(**model_kwargs)

        # Build the initializer only after the model exists: the initializer
        # ops cover the variables present in the graph at the moment they are
        # created.  If DCGAN creates its variables in its constructor (not
        # visible here), an initializer built earlier would cover none of them.
        init_op = tf.group(tf.initializers.global_variables(),
                           tf.initializers.local_variables())
        sess.run(init_op)

        # Overwrite the freshly initialized values with the saved model.
        model.restore_model(
            os.path.join(FLAGS.indir, 'model',
                         'model_{}'.format(FLAGS.model_index)))

        noise = np.random.uniform(-1.,
                                  1.,
                                  size=[FLAGS.batch_size, FLAGS.noise_dim])
        samples = model.gen_sample(noise)

        # Keep index 0 of the last of the four axes before building the montage.
        samples = samples[:, :, :, 0]
        gen_img = utils.montage(samples)
        plt.axis('off')
        plt.imshow(gen_img, cmap='gray')
        plt.show()
コード例 #6
0
ファイル: main.py プロジェクト: dougneal/tmad-tensorflow
def tf_main(_):
    """Train or restore a DCGAN, then sample from it and export to S3.

    The positional argument is ignored.

    Raises:
        Exception: when not training and ``dcgan.load`` reports failure.
    """
    # 'flags.FLAGS.__flags' is an ugly way to get at the parsed flags.
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.logs_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        dcgan = DCGAN(
            sess,
            width=FLAGS.width,
            height=FLAGS.height,
            checkpoint_dir=FLAGS.checkpoint_dir,
            logs_dir=FLAGS.logs_dir,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
        )

        # Stands in for upstream's show_all_variables() helper (model.py).
        trainable = tf.trainable_variables()
        tf.contrib.slim.model_analyzer.analyze_vars(trainable, print_info=True)

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("Model needs training first")

        # One uniform latent vector in [-0.5, 0.5) per batch element.
        latent_shape = (FLAGS.batch_size, dcgan.z_dim)
        z_sample = numpy.random.uniform(-0.5, 0.5, size=latent_shape)
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

        key_prefix = "{0}/samples/".format(dcgan.session_timestamp)
        export_images_to_s3(samples, key_prefix=key_prefix)
コード例 #7
0
def main(_):
    """Build the model selected by ``FLAGS.dataset`` and run ``FLAGS.mode``.

    For ``'mnist'`` a DCGAN is constructed.  For every other dataset a
    Discriminator is constructed, after dataset-specific FLAGS overrides for
    ``'cityscapes'`` and ``'inria'``.  ``FLAGS.mode`` then selects
    ``discriminator.test``, ``discriminator.train``, or (for ``'complete'``)
    only a message.

    Raises:
        ValueError: if the mode is ``'test'`` or ``'train'`` but no
            Discriminator was built (the ``'mnist'`` path).
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Do not take all memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.30)

    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        # Bound up front so that every dataset/mode combination reaches the
        # dispatch below with defined names instead of an unbound local.
        discriminator = None
        # NOTE(review): only 'cityscapes' supplies a synthetic-data directory;
        # confirm that Discriminator.test/train accept None for the others.
        syn_dir = None

        # w/ y label
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess,
                          image_size=FLAGS.image_size,
                          batch_size=FLAGS.batch_size,
                          y_dim=10,
                          output_size=28,
                          c_dim=1,
                          dataset_name=FLAGS.dataset,
                          checkpoint_dir=FLAGS.checkpoint_dir)
        # w/o y label
        else:
            if FLAGS.dataset == 'cityscapes':
                print('Select CITYSCAPES')
                syn_dir = CITYSCAPES_syn_dir_2
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 192, 512, False
                FLAGS.dataset_dir = CITYSCAPES_dir
            elif FLAGS.dataset == 'inria':
                print('Select INRIAPerson')
                FLAGS.output_size_h, FLAGS.output_size_w, FLAGS.is_crop = 160, 96, False
                FLAGS.dataset_dir = INRIA_dir

            discriminator = Discriminator(sess,
                                          batch_size=FLAGS.batch_size,
                                          output_size_h=FLAGS.output_size_h,
                                          output_size_w=FLAGS.output_size_w,
                                          c_dim=FLAGS.c_dim,
                                          dataset_name=FLAGS.dataset,
                                          checkpoint_dir=FLAGS.checkpoint_dir,
                                          dataset_dir=FLAGS.dataset_dir)

        if FLAGS.mode in ('test', 'train') and discriminator is None:
            raise ValueError(
                "mode %r needs a Discriminator, but dataset %r builds none"
                % (FLAGS.mode, FLAGS.dataset))

        if FLAGS.mode == 'test':
            print('Testing!')
            discriminator.test(FLAGS, syn_dir)
        elif FLAGS.mode == 'train':
            print('Train!')
            discriminator.train(FLAGS, syn_dir)
        elif FLAGS.mode == 'complete':
            print('Complete!')
コード例 #8
0
def main():
    """Load ``config.json``, wire up model, experiments and trainer, and train."""
    settings = get_config_from_json('config.json')
    gan = DCGAN(settings)
    tracker = Experiments(settings, gan)
    # The trainer receives the config, the model and the experiments object.
    runner = Trainer(settings, gan, tracker)
    runner.train()
コード例 #9
0
def main(_):
    """Build a DCGAN from FLAGS and run its anomaly detection.

    The positional argument is ignored.

    Raises:
        AssertionError: if ``FLAGS.checkpoint_dir`` does not exist.
    """
    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for compatibility.
    if not os.path.exists(FLAGS.checkpoint_dir):
        raise AssertionError(
            "checkpoint directory does not exist: {}".format(
                FLAGS.checkpoint_dir))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        dcgan = DCGAN(sess,
                      input_width=FLAGS.input_width,
                      input_height=FLAGS.input_height,
                      output_width=FLAGS.output_width,
                      output_height=FLAGS.output_height,
                      batch_size=FLAGS.batch_size,
                      sample_num=FLAGS.batch_size,
                      z_dim=FLAGS.latent_dim,
                      dataset_name=FLAGS.dataset,
                      input_fname_pattern=FLAGS.input_fname_pattern,
                      crop=FLAGS.crop,
                      checkpoint_dir=FLAGS.checkpoint_dir)

        dcgan.detect_anomaly(FLAGS)
コード例 #10
0
def main(_):
    """Check the output directories, construct a DCGAN and visualize it.

    The positional argument is ignored; settings come from the module-level
    ``config`` object.
    """
    # Checkpoint and sample directories are handed to check_dirs first.
    check_dirs([config.chk_dir, config.smp_dir])

    # Let TensorFlow grow its GPU memory use on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        gan = DCGAN(sess, config)
        visualize(sess, gan, config, option=1)
コード例 #11
0
def main(_):
    """Train or load a DCGAN and optionally visualize it.

    The positional argument is ignored.  Visualization (option 2) runs only
    when ``FLAGS.visualize`` is set.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        model_options = dict(
            image_size=FLAGS.image_size,
            output_size=FLAGS.output_size,
            batch_size=FLAGS.batch_size,
            sample_size=FLAGS.sample_size,
        )
        dcgan = DCGAN(sess, **model_options)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            # The return value of load() is not checked here.
            dcgan.load(FLAGS.checkpoint_dir)

        if FLAGS.visualize:
            visualize_option = 2
            visualize(sess, dcgan, FLAGS, visualize_option)
コード例 #12
0
ファイル: main.py プロジェクト: leconteur/DCGAN-tensorflow
def main(_):
    """Train or load a DCGAN, export generator weights, and save samples.

    The positional argument is ignored.  ``y_dim=10`` is passed to the model
    only when ``FLAGS.dataset`` is ``'mnist'``.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        # Arguments shared by both dataset cases.
        model_options = dict(
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
        )
        if FLAGS.dataset == 'mnist':
            model_options['y_dim'] = 10
        dcgan = DCGAN(sess, **model_options)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        layer_weights = (dcgan.h0_w, dcgan.h1_w, dcgan.h2_w, dcgan.h3_w, dcgan.h4_w)
        to_json("./web/js/gen_layers.js", *layer_weights)

        # One uniform latent vector in [-1, 1) per batch element.
        latent_shape = (FLAGS.batch_size, dcgan.z_dim)
        z_sample = np.random.uniform(-1, 1, size=latent_shape)
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})

        # The output file name carries the current UTC time.
        stamp = strftime("%Y-%m-%d %H:%M:%S", gmtime())
        save_images(samples, [8, 8], './samples/test_%s.png' % stamp)
コード例 #13
0
def main(_):
  """Build a DCGAN from FLAGS and train or restore it.

  The positional argument is ignored.  ``y_dim=10`` is passed to the model
  only when ``FLAGS.dataset`` is ``'mnist'``.

  Raises:
    Exception: when not training and ``dcgan.load`` reports failure.
  """
  # A missing width falls back to the corresponding height.
  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  session_config = tf.ConfigProto()
  session_config.gpu_options.allow_growth = True

  with tf.Session(config=session_config) as sess:
    # Arguments shared by both dataset cases.
    model_options = dict(
        input_width=FLAGS.input_width,
        input_height=FLAGS.input_height,
        output_width=FLAGS.output_width,
        output_height=FLAGS.output_height,
        batch_size=FLAGS.batch_size,
        sample_num=FLAGS.batch_size,
        dataset_name=FLAGS.dataset,
        input_fname_pattern=FLAGS.input_fname_pattern,
        crop=FLAGS.crop,
        checkpoint_dir=FLAGS.checkpoint_dir,
        sample_dir=FLAGS.sample_dir,
    )
    if FLAGS.dataset == 'mnist':
      model_options['y_dim'] = 10
    dcgan = DCGAN(sess, **model_options)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)
    elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
      raise Exception("[No training data found to test on] ")
コード例 #14
0
ファイル: main.py プロジェクト: AmitShah/DCGAN-tensorflow
def main(_):
    """Train or load a DCGAN, export its layers to JSON, then visualize.

    The positional argument is ignored.  ``y_dim=10`` is passed to the model
    only when ``FLAGS.dataset`` is ``'mnist'``.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        # Arguments shared by both dataset cases.
        model_options = dict(
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
        )
        if FLAGS.dataset == 'mnist':
            model_options['y_dim'] = 10
        dcgan = DCGAN(sess, **model_options)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        # One [weights, biases, batch-norm] triple per layer; the last layer
        # passes None in the third slot.
        layers = [
            [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
            [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
            [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
            [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
            [dcgan.h4_w, dcgan.h4_b, None],
        ]
        to_json("./web/js/layers.js", *layers)

        visualize_option = 2
        visualize(sess, dcgan, FLAGS, visualize_option)
コード例 #15
0
ファイル: main.py プロジェクト: awp4211/DLAdvance
def main(_):
    """Build a DCGAN with ``y_dim=10``, train or restore it, then visualize.

    The positional argument is ignored.

    Raises:
        Exception: when not training and ``dcgan.load`` reports failure.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model_options = dict(
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            y_dim=10,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        dcgan = DCGAN(sess, **model_options)

        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")

        # Visualization runs in both train and test mode, using option 1.
        visualize_option = 1
        visualize(sess, dcgan, FLAGS, visualize_option)
コード例 #16
0
def main(_):
    """Build a DCGAN from FLAGS, run ``complete``, then visualize.

    The positional argument is ignored.

    Raises:
        AssertionError: if ``FLAGS.checkpoint_dir`` does not exist.
    """
    pp.pprint(flags.FLAGS.__flags)

    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height

    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for compatibility.
    if not os.path.exists(FLAGS.checkpoint_dir):
        raise AssertionError(
            "checkpoint directory does not exist: {}".format(
                FLAGS.checkpoint_dir))

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(sess,
                      input_width=FLAGS.input_width,
                      input_height=FLAGS.input_height,
                      output_width=FLAGS.output_width,
                      output_height=FLAGS.output_height,
                      batch_size=FLAGS.batch_size,
                      dataset_name=FLAGS.dataset,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      z_dim=FLAGS.z_dim,
                      lam=FLAGS.lam)
        # No explicit dcgan.load() call is made before complete().
        dcgan.complete(FLAGS)

        show_all_variables()

        # Below is codes for visualization
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION)
コード例 #17
0
def main(_):
    """Pick the next numbered run folder and train a DCGAN into it.

    Run folders are the integer-named sub-directories of
    ``FLAGS.summary_dir``.  The new run gets the highest existing number plus
    one (0 when there is none), zero-padded to three digits.  Both
    ``FLAGS.summary_dir`` and ``FLAGS.checkpoint_dir`` are redirected into
    that folder.  The positional argument is ignored.
    """
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.summary_dir):
        os.makedirs(FLAGS.summary_dir)

    # Sub-directories directly under summary_dir.  Names that are not integers
    # are skipped, so an unrelated folder no longer aborts startup with a
    # ValueError.
    run_numbers = []
    for name in next(os.walk(FLAGS.summary_dir))[1]:
        try:
            run_numbers.append(int(name))
        except ValueError:
            continue
    run_nr = max(run_numbers) + 1 if run_numbers else 0
    run_folder = str(run_nr).zfill(3)

    FLAGS.summary_dir = os.path.join(FLAGS.summary_dir, run_folder)
    FLAGS.checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, run_folder)
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        dcgan = DCGAN(sess, batch_size=FLAGS.batch_size)
        # The per-run directories are created only after the model is built.
        for directory in (FLAGS.checkpoint_dir, FLAGS.summary_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        dcgan.train(FLAGS, run_folder)
コード例 #18
0
def main(_):
    """Restore or train a DCGAN, generate output, and record the flags.

    A log directory is created under ``FLAGS.root_dir``; its name is the
    current local time followed by the regularization, optimizer, step,
    learning-rate, epoch and dataset settings.  If ``FLAGS.checkpoint_dir`` is
    not the string ``"None"`` a checkpoint is loaded, otherwise the model is
    trained and the elapsed time printed.  The flags are appended as JSON to
    ``flags.json`` in the log directory.  The positional argument is ignored.

    Raises:
        Exception: if a checkpoint directory is given but loading fails.
    """
    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    name_parts = [time.strftime("%Y_%m_%d_%H%M", time.localtime())]
    if FLAGS.unreg:
        name_parts.append("_unreg_dcgan")
    else:
        name_parts.append("_regularized_dcgan_" + str(FLAGS.gamma) + "gamma")
    if FLAGS.annealing:
        name_parts.append("_annealing_" + str(FLAGS.decay_factor) + "decayfactor")
    name_parts.append("_rmsprop" if FLAGS.rmsprop else "_adam")
    name_parts.append("_" + str(FLAGS.disc_update_steps) + "dsteps")
    name_parts.append("_" + str(FLAGS.disc_learning_rate) + "dlnr")
    name_parts.append("_" + str(FLAGS.gen_learning_rate) + "glnr")
    name_parts.append("_" + str(FLAGS.epochs) + "epochs")
    name_parts.append("_" + str(FLAGS.dataset))
    file_name = "".join(name_parts)

    log_dir = os.path.abspath(os.path.join(FLAGS.root_dir, file_name))

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    pp.pprint(flags.FLAGS.__flags)

    with tf.Session() as sess:

        dcgan = DCGAN(sess, log_dir, FLAGS)

        show_all_variables()

        # Compare by value: ``is not "None"`` tested object identity, which is
        # true for any equal string that is a different object.
        if FLAGS.checkpoint_dir != "None":
            if not dcgan.load_ckpt()[0]:
                raise Exception("[!] ERROR: provide valid checkpoint_dir")
        else:
            starttime = time.time()

            dcgan.train(FLAGS)

            endtime = time.time()
            print('Total Train Time: {:.2f}'.format(endtime - starttime))

        dcgan.generate(FLAGS, option=1)

    # The context manager closes the file even if json.dump raises.
    with open(os.path.join(log_dir, "flags.json"), 'a') as flags_file:
        json.dump(vars(FLAGS), flags_file)
コード例 #19
0
ファイル: main.py プロジェクト: epiasini/DCGAN-tensorflow
def main(_):
    """Train or load a DCGAN and optionally export layers and visualize.

    The positional argument is ignored.  For ``'mnist'`` the model gets fixed
    ``y_dim=10``, ``output_size=28`` and ``c_dim=1``; for any other dataset
    those sizes plus ``log_dir``, ``sample_dir`` and ``vgg_reg`` come from
    FLAGS.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        # Arguments shared by both dataset cases.
        model_options = dict(
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
        )
        if FLAGS.dataset == 'mnist':
            model_options.update(y_dim=10, output_size=28, c_dim=1)
        else:
            # These flags are read only on the non-mnist path.
            model_options.update(
                output_size=FLAGS.output_size,
                c_dim=FLAGS.c_dim,
                log_dir=FLAGS.log_dir,
                sample_dir=FLAGS.sample_dir,
                vgg_reg=FLAGS.vgg_reg,
            )
        dcgan = DCGAN(sess, **model_options)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        if FLAGS.visualize:
            # One [weights, biases, batch-norm] triple per layer; the last
            # layer passes None in the third slot.
            layers = [
                [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
                [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
                [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
                [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
                [dcgan.h4_w, dcgan.h4_b, None],
            ]
            to_json("./web/js/layers.js", *layers)

            visualize_option = 2
            visualize(sess, dcgan, FLAGS, visualize_option)
コード例 #20
0
def main(_):
    """Train a DCGAN with ``train.train_wasserstein`` or load a checkpoint.

    The positional argument is ignored.  When training, the data is either
    the array returned by ``get_data_arr(FLAGS)`` (``FLAGS.preload_data``
    set) or the list of ``*.jpg`` paths under ``./data/<dataset>``.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        dcgan = DCGAN(sess,
                      dataset=FLAGS.dataset,
                      batch_size=FLAGS.batch_size,
                      output_size=FLAGS.output_size,
                      c_dim=FLAGS.c_dim)

        if FLAGS.is_train:
            # Plain truthiness test instead of ``== True`` (PEP 8).
            if FLAGS.preload_data:
                data = get_data_arr(FLAGS)
            else:
                data = glob(os.path.join('./data', FLAGS.dataset, '*.jpg'))
            train.train_wasserstein(sess, dcgan, data, FLAGS)
        else:
            # The return value of load() is not checked here.
            dcgan.load(FLAGS.checkpoint_dir)
コード例 #21
0
def main(_):
    """Build a 3-channel DCGAN, load its checkpoint, train or test, visualize.

    The checkpoint is loaded once, before the train/test decision.  In train
    mode the outcome of that load is ignored; in test mode a failed load
    raises.  The positional argument is ignored.

    Raises:
        Exception: when not training and the checkpoint load reported failure.
    """
    pp.pprint(flags.FLAGS.__flags)

    # A missing width falls back to the corresponding height.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(sess,
                      c_dim=3,
                      input_width=FLAGS.input_width,
                      input_height=FLAGS.input_height,
                      output_width=FLAGS.output_width,
                      output_height=FLAGS.output_height,
                      batch_size=FLAGS.batch_size,
                      sample_num=FLAGS.batch_size,
                      dataset_name=FLAGS.dataset,
                      crop=FLAGS.crop,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      adversarial_path=FLAGS.adversarial_path,
                      ground_truth_path=FLAGS.ground_truth_path,
                      test_path=FLAGS.test_path,
                      save_path=FLAGS.save_path)

        # Keep the status of the single load instead of discarding it and
        # loading the same checkpoint a second time in test mode.
        checkpoint_loaded = dcgan.load(FLAGS.checkpoint_dir)[0]

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not checkpoint_loaded:
            raise Exception("[!] Train a model first, then run test mode")

        # Below is codes for visualization
        OPTION = 1
        visualize(sess, dcgan, FLAGS, OPTION)
コード例 #22
0
def main(_):
    """Build a DCGAN from the input height/width flags and train it.

    The positional argument is ignored.  The session uses TensorFlow's
    default configuration.
    """
    with tf.Session() as sess:
        height, width = FLAGS.input_height, FLAGS.input_width
        model = DCGAN(sess, height, width)
        # The graph is built explicitly before training starts.
        model.build_model()
        model.train(FLAGS)
コード例 #23
0
ファイル: train_patents.py プロジェクト: 255BITS/improved-gan
def main(_):
    """Build a multi-GPU DCGAN, train or load it, then visualize.

    The positional argument is ignored.

    Raises:
        AssertionError: if ``FLAGS.dataset`` is ``'mnist'``.
    """
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session(config=tf.ConfigProto(
              allow_soft_placement=True, log_device_placement=False)) as sess:
        if FLAGS.dataset == 'mnist':
            # Raised explicitly: a bare ``assert False`` disappears under
            # ``python -O`` and would let the mnist path fall through.
            raise AssertionError("the 'mnist' dataset is rejected by this script")
        dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                    sample_size = 64,
                    z_dim = 8192,
                    d_label_smooth = .25,
                    generator_target_prob = .75 / 2.,
                    out_stddev = .075,
                    out_init_b = - .45,
                    image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
                    dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir,
                    sample_dir=FLAGS.sample_dir,
                    generator=Generator(),
                    train_func=train, discriminator_func=discriminator,
                    build_model_func=build_model, config=FLAGS,
                    devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"] #, "gpu:4"]
                    )

        if FLAGS.is_train:
            # Single-argument print() calls give the same output on
            # Python 2 and Python 3.
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        OPTION = 2
        visualize(sess, dcgan, FLAGS, OPTION)
コード例 #24
0
def main(_):
    """Train or restore a DCGAN, visualize it, and optionally train vectors.

    The positional argument is ignored; every setting is read from
    ``FLAGS``.

    Raises:
        Exception: if ``FLAGS.train`` is false and no checkpoint can be
            loaded from ``FLAGS.checkpoint_dir``.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Fill in missing widths: images default to square.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # Create the output directories.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # GPU configuration: allocate GPU memory on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    # Create the TF session and build the model.
    with tf.Session(config=session_config) as sess:
        # NOTE(review): z_dim is fed from FLAGS.generate_test_images --
        # confirm this flag really carries the latent dimension.
        model = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            z_dim=FLAGS.generate_test_images,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            data_dir=FLAGS.data_dir,
        )

        show_all_variables()

        # Train when FLAGS.train is set; otherwise load saved parameters.
        if FLAGS.train:
            model.train(FLAGS)
        elif not model.load(FLAGS.checkpoint_dir)[0]:
            raise Exception("[!] Train a model first, then run test mode")

        visualization_mode = 1
        visualize(sess, model, FLAGS, visualization_mode)

        if FLAGS.vec_t:
            model.train_vec(FLAGS)
コード例 #25
0
def main(_):
    """Train or restore a DCGAN, then run visualization mode 5.

    The positional argument is ignored; every setting is read from
    ``FLAGS``. The checkpoint and sample directories are created when
    missing. For the 'mnist' dataset the model is built with ``y_dim=10``,
    ``output_size=28`` and ``c_dim=1``; otherwise the output size and
    channel count come from ``FLAGS``.

    The session graph is written through a ``tf.summary.FileWriter``
    pointed at the hard-coded directory ``'path/to/logs'``; the writer is
    closed when visualization finishes so the event file is flushed.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        # Arguments shared by both dataset variants.
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        if FLAGS.dataset == 'mnist':
            model_kwargs.update(y_dim=10, output_size=28, c_dim=1)
        else:
            model_kwargs.update(output_size=FLAGS.output_size,
                                c_dim=FLAGS.c_dim)
        dcgan = DCGAN(sess, **model_kwargs)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        # Below is code for visualization.
        OPTION = 5
        file_writer = tf.summary.FileWriter('path/to/logs', sess.graph)
        try:
            visualize(sess, dcgan, FLAGS, OPTION)
        finally:
            # Close the writer even if visualization fails, so the graph
            # event is flushed to disk.
            file_writer.close()
コード例 #26
0
ファイル: train.py プロジェクト: baoqianyue/MachineLearning
def main():
    """Train the DCGAN on images from ``src_dir`` and save sample grids.

    Configuration (``output_dir``, ``src_dir``, ``batch_size``, ``n_noise``,
    ``total_epoch``) is read from module-level names. A fixed noise batch of
    14 * 14 vectors is drawn once and reused for every saved sample grid.
    """
    # Create the output folder.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    database = provider.DBreader(
        src_dir, batch_size, resize=[64, 64, 3], labeled=False)
    sess = tf.Session()
    model = DCGAN(sess, batch_size)
    sess.run(tf.global_variables_initializer())

    total_batch = database.total_batch
    sample_count = 14 * 14
    noise_z = np.random.normal(size=(sample_count, n_noise))

    loss_d = loss_g = 0.0
    for epoch in range(total_epoch):
        epoch_tag = str(epoch).zfill(3)
        for step in range(total_batch):
            # Scale the batch to the range -1 ~ 1.
            batch_imgs = database.next_batch() / 127.5 - 1
            noise_g = np.random.normal(size=(batch_size, n_noise))
            noise_d = np.random.normal(size=(batch_size, n_noise))

            # NOTE(review): both networks are updated here and then again
            # below on the same batch -- confirm the double update is
            # intentional.
            loss_d = model.train_discriminator(batch_imgs, noise_d)
            loss_g = model.train_generator(noise_g)

            # During the first 200 steps of epoch 0 the second discriminator
            # update only runs on every other step.
            d_interval = 2 if (epoch == 0 and step < 200) else 1
            if step % d_interval == 0:
                # Train Discriminator and get the loss value
                loss_d = model.train_discriminator(batch_imgs, noise_d)
            # Train Generator and get the loss value
            loss_g = model.train_generator(noise_g)

            print('Epoch: [', epoch, '/', total_epoch, '], ', 'Step: [', step,
                  '/', total_batch, '], D_loss: ', loss_d, ', G_loss: ',
                  loss_g)

            # Save a sample grid on the first step and on every 10th step.
            if step == 0 or (step + 1) % 10 == 0:
                generated_samples = model.generator_sample(
                    noise_z, batch_size=sample_count)
                step_tag = str(step).zfill(6)
                savepath = (output_dir + '/output_' + 'EP' + epoch_tag +
                            "_Batch" + step_tag + '.jpg')
                save_img(generated_samples, (14, 14), save_path=savepath)
コード例 #27
0
ファイル: main.py プロジェクト: ccyyatnet/DCGAN
def main(_):
    """Train a DCGAN, or load a checkpoint and run one of its test modes.

    The positional argument is ignored; settings come from ``FLAGS``. The
    sample and checkpoint directories are derived from ``FLAGS.result_dir``,
    ``FLAGS.dataset`` and ``FLAGS.dir_tag`` and created when missing.
    """
    pp.pprint(FLAGS.__flags)

    # NOTE(review): is_grayscale is derived from c_dim here, before the
    # mnist branch below overrides c_dim to 1 -- confirm the ordering.
    FLAGS.is_grayscale = (FLAGS.c_dim == 1)
    run_tag = FLAGS.dataset + '_' + FLAGS.dir_tag
    FLAGS.sample_dir = FLAGS.result_dir + 'samples/' + run_tag
    FLAGS.checkpoint_dir = FLAGS.result_dir + 'checkpoint/' + run_tag

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        # MNIST runs with fixed image size and a single channel.
        if FLAGS.dataset == 'mnist':
            FLAGS.image_size = 32
            FLAGS.c_dim = 1

        with tf.device(FLAGS.devices):
            model = DCGAN(sess, config=FLAGS)

            if FLAGS.is_train:
                model.train(FLAGS)
            elif not model.load(FLAGS):
                print(" [!] Load failed...")
            else:
                print(" [*] Load SUCCESS")
                if FLAGS.random_z:
                    print(" [*] Test RANDOM Z")
                    model.test_fix(FLAGS)
                else:
                    print(" [*] Test Z")
                    model.test_z(FLAGS)
コード例 #28
0
def main(_):
    """Train or load a DCGAN, then save one batch of generated samples.

    The positional argument is ignored; settings come from ``FLAGS``. The
    samples are written to ``./samples/test.png`` as a 14 x 14 grid.
    """
    pp.pprint(flags.FLAGS.__flags)

    with tf.Session() as sess:
        # Only the MNIST variant passes a label dimension.
        label_kwargs = {'y_dim': 10} if FLAGS.dataset == 'mnist' else {}
        model = DCGAN(sess, batch_size=FLAGS.batch_size, **label_kwargs)

        if not FLAGS.is_train:
            model.load(FLAGS.checkpoint_dir)
        else:
            model.train(FLAGS)

        # Draw one batch of latent vectors uniformly from [-1, 1).
        latent_shape = (FLAGS.batch_size, model.z_dim)
        z_sample = np.random.uniform(-1, 1, size=latent_shape)

        generated = sess.run(model.sampler, feed_dict={model.z: z_sample})
        save_images(generated, [14, 14], './samples/test.png')
コード例 #29
0
ファイル: main.py プロジェクト: 255BITS/DCGAN-tensorflow
def main(_):
    """Train or load the wav-based DCGAN, export layers, then visualize.

    The positional argument is ignored; settings come from ``FLAGS``. The
    checkpoint and sample directories are created when missing. After
    training or loading, the h0..h3 layer attributes are passed to
    ``to_json`` and visualization mode 1 is run.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        model_kwargs = dict(
            wav_size=FLAGS.wav_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
        )
        # Only the MNIST variant passes a label dimension.
        if FLAGS.dataset == 'mnist':
            model_kwargs['y_dim'] = 10
        model = DCGAN(sess, **model_kwargs)

        if not FLAGS.is_train:
            model.load(FLAGS.checkpoint_dir)
        else:
            print("TRAIN TIME")
            model.train(FLAGS)

        # TODO FIXME: the h4 layer ([h4_w, h4_b, None]) is not exported.
        exported_layers = [
            [model.h0_w, model.h0_b, model.g_bn0],
            [model.h1_w, model.h1_b, model.g_bn1],
            [model.h2_w, model.h2_b, model.g_bn2],
            [model.h3_w, model.h3_b, model.g_bn3],
        ]
        to_json("./web/js/layers.js", *exported_layers)

        # Below is code for visualization.
        visualization_mode = 1
        visualize(sess, model, FLAGS, visualization_mode)
コード例 #30
0
ファイル: main.py プロジェクト: JohannesHeidecke/UPC-MAI-CI
def main(_):
    """Train or load a DCGAN and run visualization mode 0.

    The positional argument is ignored; settings come from ``FLAGS``. The
    checkpoint and sample directories are created when missing. Only the
    'mnist' variant passes ``y_dim``, ``d`` and ``k`` to the model and uses
    the fixed ``output_size=28`` / ``c_dim=1``.
    """
    pp.pprint(flags.FLAGS.__flags)

    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        # Arguments shared by both dataset variants.
        model_kwargs = dict(
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        if FLAGS.dataset == 'mnist':
            model_kwargs.update(y_dim=10, output_size=28, c_dim=1,
                                d=FLAGS.d, k=FLAGS.k)
        else:
            model_kwargs.update(output_size=FLAGS.output_size,
                                c_dim=FLAGS.c_dim)
        model = DCGAN(sess, **model_kwargs)

        if not FLAGS.is_train:
            model.load(FLAGS.checkpoint_dir)
        else:
            model.train(FLAGS)

        visualization_mode = 0
        visualize(sess, model, FLAGS, visualization_mode, FLAGS.d,
                  FLAGS.sample_dir)
コード例 #31
0
ファイル: main.py プロジェクト: johndpope/easygen
def run(checkpoint_dir='checkpoints', batch_size=64, input_height=108,
        input_width=None, output_height=64, output_width=None,
        dataset='celebA', input_fname_pattern='*.jpg', output_dir='output',
        sample_dir='samples', crop=True):
  """Build a DCGAN, initialize its variables and run the visualization.

  Args:
    checkpoint_dir: passed to the model as ``checkpoint_dir``.
    batch_size: used as both ``batch_size`` and ``sample_num`` of the model.
    input_height: input image height.
    input_width: input image width; defaults to ``input_height`` when None.
    output_height: output image height.
    output_width: output image width; defaults to ``output_height`` when None.
    dataset: passed to the model as ``dataset_name``.
    input_fname_pattern: passed to the model unchanged.
    output_dir: passed to the model and to ``visualize``.
    sample_dir: passed to the model as ``sample_dir``.
    crop: passed to the model unchanged.

  No directories are created here; the corresponding code is disabled.
  """
  if input_width is None:
    input_width = input_height
  if output_width is None:
    output_width = output_height

  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth = True

  with tf.Session(config=run_config) as sess:
    dcgan = DCGAN(
        sess,
        input_width=input_width,
        input_height=input_height,
        output_width=output_width,
        output_height=output_height,
        batch_size=batch_size,
        sample_num=batch_size,
        dataset_name=dataset,
        input_fname_pattern=input_fname_pattern,
        crop=crop,
        checkpoint_dir=checkpoint_dir,
        sample_dir=sample_dir,
        output_dir=output_dir)

    show_all_variables()

    # Older TensorFlow releases lack tf.global_variables_initializer; fall
    # back to the legacy name only in that case. A bare `except:` would
    # also swallow KeyboardInterrupt and genuine initialization errors.
    try:
      init_op = tf.global_variables_initializer()
    except AttributeError:
      init_op = tf.initialize_all_variables()
    init_op.run()

    # Below is code for visualization
    visualize(sess, dcgan, batch_size=batch_size, input_height=input_height,
              input_width=input_width, output_dir=output_dir)
コード例 #32
0
def main(argv):
    """Train a DCGAN on MNIST, logging summaries and periodic checkpoints.

    ``argv`` is unused. Each iteration runs one discriminator step followed
    by two generator steps, each with freshly drawn noise. Summaries and
    checkpoints are written under ``./models``; a checkpoint is saved every
    100 iterations.
    """
    mnist = input_data.read_data_sets('data/mnist', one_hot=True)
    X = tf.placeholder(tf.float32, shape=[None, 784])
    z = tf.placeholder(tf.float32, shape=[None, 100])

    batch_size = 64
    model = DCGAN(FLAGS.height, FLAGS.width, X, z, batch_size)

    d_loss, g_loss = model.D_loss, model.G_loss
    g_solver, d_solver = model.G_solver, model.D_solver
    summary_op = model.merged

    def sample_noise():
        # One batch of 100-dimensional noise drawn uniformly from [0, 1).
        return np.random.uniform(0., 1., size=[batch_size, 100])

    with tf.Session() as sess:
        saver = tf.train.Saver()
        writer = tf.summary.FileWriter('./models', sess.graph)
        sess.run(tf.global_variables_initializer())

        for step in range(FLAGS.epoch):
            images, _ = mnist.train.next_batch(batch_size)

            # Discriminator step; the summary and both losses are fetched
            # in the same run call.
            _, summary, _, _ = sess.run(
                [d_solver, summary_op, d_loss, g_loss],
                {X: images, z: sample_noise()})

            # Two generator steps per discriminator step.
            for _update in range(2):
                sess.run([g_solver], {z: sample_noise()})

            writer.add_summary(summary, global_step=step)
            if step % 100 == 0:
                saver.save(sess, './models/dcgan.ckpt', global_step=step)

        writer.close()
コード例 #33
0
ファイル: main.py プロジェクト: minoring/dcgan
def run(FLAGS):
  """Open a TF1-compat session and build the DCGAN for the MNIST dataset.

  Args:
    FLAGS: configuration object whose attributes supply the model options.

  In the code visible here a model is only constructed when
  ``FLAGS.dataset`` is 'mnist'.
  """
  session_config = tf.compat.v1.ConfigProto()
  # Allocate GPU memory on demand.
  session_config.gpu_options.allow_growth = True

  with tf.compat.v1.Session(config=session_config) as sess:
    if FLAGS.dataset == 'mnist':
      model_options = dict(
          input_height=FLAGS.input_height,
          input_width=FLAGS.input_width,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.z_dim,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir,
          out_dir=FLAGS.out_dir,
          max_to_keep=FLAGS.max_to_keep,
      )
      dcgan = DCGAN(sess, **model_options)
コード例 #34
0
# Command-line options for the image-completion script. `parser` is created
# above this point (presumably an argparse.ArgumentParser -- confirm).
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--eps', type=float, default=1e-8)
parser.add_argument('--hmcBeta', type=float, default=0.2)
parser.add_argument('--hmcEps', type=float, default=0.001)
parser.add_argument('--hmcL', type=int, default=100)
parser.add_argument('--hmcAnneal', type=float, default=1)
parser.add_argument('--nIter', type=int, default=1000)
parser.add_argument('--imgSize', type=int, default=64)
parser.add_argument('--lam', type=float, default=0.1)
parser.add_argument('--checkpointDir', type=str, default='checkpoint')
parser.add_argument('--outDir', type=str, default='completions')
parser.add_argument('--outInterval', type=int, default=50)
parser.add_argument('--maskType', type=str,
                    choices=['random', 'center', 'left', 'full', 'grid', 'lowres'],
                    default='center')
parser.add_argument('--centerScale', type=float, default=0.25)
# One or more input image paths (positional).
parser.add_argument('imgs', type=str, nargs='+')

args = parser.parse_args()

# Validate explicitly instead of with `assert`, which is stripped when Python
# runs with -O and would let a missing directory fail later.
if not os.path.exists(args.checkpointDir):
    raise SystemExit(
        "Checkpoint directory not found: %s" % args.checkpointDir)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # The batch never exceeds 64 images and shrinks to the number of inputs.
    dcgan = DCGAN(sess, image_size=args.imgSize,
                  batch_size=min(64, len(args.imgs)),
                  checkpoint_dir=args.checkpointDir, lam=args.lam)
    dcgan.complete(args)
コード例 #35
0
ファイル: trainer.py プロジェクト: MasazI/dcgan
def train():
    """Build the DCGAN graph and run the alternating training loop.

    Configuration is read from module-level names (DATA_DIR, SAMPLE_SIZE,
    CSVFILE, Z_DIM, BATCH_SIZE, EPOCHS, LOG_DEVICE_PLACEMENT, FLAGS).
    Each iteration runs one discriminator update followed by two generator
    updates on the same noise batch and writes their summaries under
    ``./logs``. Whenever ``counter % 10 == 1`` the output of
    ``dcgan.generate_images`` is written to a file in ``FLAGS.sample_dir``.

    A Saver is created but no checkpoint is saved in this function.
    """
    with tf.Graph().as_default():
        # data
        dataset = traindataset.DataSet(DATA_DIR, SAMPLE_SIZE)
        # tfrecords inputs
        images, labels_t = dataset.csv_inputs(CSVFILE)

        # Latent-vector placeholder fed with fresh noise on every step.
        z = tf.placeholder(tf.float32, [None, Z_DIM], name='z')

        dcgan = DCGAN("test", "./checkpoint")
        images_inf, logits, logits_, G_sum, z_sum, d_sum, d__sum = dcgan.inference(images, z)
        d_loss_fake, d_loss_real, d_loss_real_sum, d_loss_fake_sum, d_loss_sum, g_loss_sum, d_loss, g_loss = dcgan.loss(logits, logits_)

        # trainable variables
        # Variables are assigned to each network by substring match on the
        # variable name ('d_' / 'g_').
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # train operations
        d_optim = D_train_op(d_loss, d_vars)
        g_optim = G_train_op(g_loss, g_vars)

        # generate images
        generate_images = dcgan.generate_images(z, 4, 4)

        # summary
        # Note: d_sum is rebound on the second line; inside the list it still
        # refers to the summary returned by dcgan.inference above.
        g_sum = tf.merge_summary([z_sum, d__sum, G_sum, d_loss_fake_sum, g_loss_sum])
        d_sum = tf.merge_summary([z_sum, d_sum, d_loss_real_sum, d_loss_sum])
        #summary_op = tf.merge_all_summaries()

        # init operation
        init_op = tf.initialize_all_variables()

        # Session
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=LOG_DEVICE_PLACEMENT))
        writer = tf.train.SummaryWriter("./logs", sess.graph_def)

        # saver
        # Created here but not used to save within this function.
        saver = tf.train.Saver(tf.all_variables())

        # run
        sess.run(init_op)

        # Start the queue-runner threads; they are stopped and joined after
        # the training loop.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # sampling
        # sample_z is not referenced again below; samples are generated from
        # the current batch_z instead.
        sample_z = np.random.uniform(-1, 1, size=(SAMPLE_SIZE, Z_DIM))

        # sample images
        #sample_images = dataset.get_sample()

        counter = 1
        start_time = time.time()

        for epoch in xrange(EPOCHS):
            for idx in xrange(0, dataset.batch_idxs):
                #batch_images = dataset.create_batch()
                # One noise batch per iteration, shared by all runs below.
                batch_z = np.random.uniform(-1, 1, [BATCH_SIZE, Z_DIM]).astype(np.float32)

                # discriminative
                images_inf_eval, _, summary_str = sess.run([images_inf, d_optim, d_sum], {z: batch_z})
                writer.add_summary(summary_str, counter)

                #for i, image_inf in enumerate(images_inf_eval):
                #    print np.uint8(image_inf)
                #    print image_inf.shape
                #    #image_inf_reshape = image_inf.reshape([64, 64, 3])
                #    img = Image.fromarray(np.asarray(image_inf), 'RGB')
                #    print img
                #    img.save('verbose/%d_%d.png' % (counter, i))

                # generative
                _, summary_str = sess.run([g_optim, g_sum], {z: batch_z})
                writer.add_summary(summary_str, counter)

                # twice optimization
                _, summary_str = sess.run([g_optim, g_sum], {z: batch_z})
                writer.add_summary(summary_str, counter)

                # Losses are evaluated in three separate runs, after the
                # updates above, purely for the log line.
                errD_fake = sess.run(d_loss_fake, {z: batch_z})
                errD_real = sess.run(d_loss_real, {z: batch_z})
                errG = sess.run(g_loss, {z: batch_z})

                print("epochs: %02d %04d/%04d time: %4.4f, d_loss: %.8f, g_loss: %.8f" % (epoch, idx, dataset.batch_idxs,time.time() - start_time, errD_fake + errD_real, errG))

                if np.mod(counter, 10) == 1:
                    print("generate samples.")
                    generated_image_eval = sess.run(generate_images, {z: batch_z})
                    filename = os.path.join(FLAGS.sample_dir, 'out_%05d.png' % counter)
                    # Written in binary mode as-is; presumably the op yields
                    # encoded PNG bytes -- TODO confirm.
                    with open(filename, 'wb') as f:
                        f.write(generated_image_eval)
                counter += 1
        # NOTE(review): this cleanup is skipped if the loop raises.
        coord.request_stop()
        coord.join(threads)
        sess.close()
コード例 #36
0
ファイル: main.py プロジェクト: jeromeyoon/GAN_IR
def _unit_normals(sample):
    """Scale each pixel's 3-vector to unit length and map it to [0, 1].

    ``sample`` is squeezed and cast to float32 and must then be H x W x 3.
    Pixels whose vector has zero length produce 0/0 = NaN (not inf) in the
    division; every non-finite entry is replaced by 0.0 before the result
    is shifted with ``(x + 1) / 2``.
    """
    sample = np.squeeze(sample).astype(np.float32)
    squared = np.power(sample, 2)
    norm = np.sqrt(squared[:, :, 0] + squared[:, :, 1] + squared[:, :, 2])
    # Division by a zero norm is expected and handled right below.
    with np.errstate(divide='ignore', invalid='ignore'):
        output = sample[:, :, :3] / norm[:, :, np.newaxis]
    output[~np.isfinite(output)] = 0.0
    return (output + 1.) / 2.


def _predict_normals(sess, model, ir_image):
    """Resize one IR image to 600x800, run the sampler, return unit normals."""
    ir_image = scipy.misc.imresize(ir_image, [600, 800])
    ir_image = np.reshape(ir_image, (1, 600, 800, 1))
    ir_image = np.array(ir_image).astype(np.float32)
    start_time = time.time()
    sample = sess.run(model.sampler, feed_dict={model.ir_images: ir_image})
    print('time: %.8f' % (time.time() - start_time))
    return _unit_normals(sample)


def main(_):
    """Train the IR-to-normal DCGAN, or evaluate a checkpoint on test images.

    The positional argument is ignored; settings come from ``FLAGS``. In
    training mode a ``DCGAN`` is built and trained. Otherwise an ``EVAL``
    model is built for 600x800 inputs, the checkpoint is loaded, and one of
    two hard-coded validation sets (selected by ``VAL_OPTION``) is run
    through the sampler. Each normalized prediction is saved as a BMP under
    a hard-coded output path.

    The previous version also loaded a ground-truth image on every
    iteration without using it; that dead work has been removed.
    """
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if FLAGS.is_train:
            dcgan = DCGAN(sess, image_size=FLAGS.image_size,
                          batch_size=FLAGS.batch_size,
                          input_size=FLAGS.input_size,
                          dataset_name=FLAGS.dataset,
                          is_crop=FLAGS.is_crop,
                          checkpoint_dir=FLAGS.checkpoint_dir)
            dcgan.train(FLAGS)
            return

        dcgan = EVAL(sess, input_size=600, batch_size=1,
                     ir_image_shape=[600, 800, 1],
                     normal_image_shape=[600, 800, 3],
                     dataset_name=FLAGS.dataset,
                     is_crop=False, checkpoint_dir=FLAGS.checkpoint_dir)
        dcgan.load(FLAGS.checkpoint_dir)

        # 1: per-material validation images, 2: the "extra" image set.
        VAL_OPTION = 2
        if VAL_OPTION == 1:
            list_val = [11, 16, 21, 22, 33, 36, 38, 53, 59, 92]
            for material in list_val:
                for idx2 in range(1, 10):
                    print("Selected material %03d/%d" % (material, idx2))
                    img = '/research2/IR_normal_small/save%03d/%d' % (material, idx2)
                    input_ = scipy.misc.imread(img + '/3.bmp').astype(float)
                    sample = _predict_normals(sess, dcgan, input_)
                    savename = '/home/yjyoon/Dropbox/ECCV16_IRNormal/single_result/%03d/%d/single_normal_L2ang.bmp' % (material, idx2)
                    scipy.misc.imsave(savename, sample)

        elif VAL_OPTION == 2:
            print("Computing all validation set ")
            num_img = 13
            for idx in range(5, num_img + 1):
                print("[Computing Validation Error %d/%d]" % (idx, num_img))
                img = '/home/yjyoon/Dropbox/ECCV16_IRNormal/extra/extra_%d.bmp' % (idx)
                input_ = scipy.misc.imread(img).astype(float)
                # Only the first channel of the input image is used.
                input_ = input_[:, :, 0]
                sample = _predict_normals(sess, dcgan, input_)
                savename = '/home/yjyoon/Dropbox/ECCV16_IRNormal/extra/extra_result%d.bmp' % (idx)
                scipy.misc.imsave(savename, sample)
コード例 #37
0
ファイル: main.py プロジェクト: amoliu/DCGAN-tensorflow
def main(_):
    """Train or load a DCGAN, export its layers, then run one visualization.

    The positional argument is ignored; settings come from ``FLAGS``. The
    checkpoint and sample directories are created when missing. After
    training or loading, the h0..h4 layer attributes are passed to
    ``to_json`` (target ``./web/js/layers.js``), and the visualization mode
    selected by the hard-coded ``OPTION`` is executed. All outputs go under
    ``./samples``.

    NOTE(review): this code relies on Python 2 semantics (``xrange``, list
    concatenation of ``range`` results, integer ``/`` in ``idx/8``).
    """
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    with tf.Session() as sess:
        # Only the MNIST variant passes a label dimension (y_dim=10).
        if FLAGS.dataset == 'mnist':
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, y_dim=10,
                    dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)
        else:
            dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                    dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir)

        if FLAGS.is_train:
            dcgan.train(FLAGS)
        else:
            dcgan.load(FLAGS.checkpoint_dir)

        to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
                                      [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
                                      [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
                                      [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
                                      [dcgan.h4_w, dcgan.h4_b, None])

        # Below is codes for visualization
        # OPTION is hard-coded, so only the matching branch below ever runs.
        OPTION = 2
        if OPTION == 0:
          # One random batch with z uniform in [-0.5, 0.5), saved as a grid.
          z_sample = np.random.uniform(-0.5, 0.5, size=(FLAGS.batch_size, dcgan.z_dim))
          samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
          save_images(samples, [8, 8], './samples/test_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
        elif OPTION == 1:
          # For each of the first 100 z dimensions: all-zero z, with that one
          # dimension swept over `values` across the batch; saved as a grid.
          values = np.arange(0, 1, 1./FLAGS.batch_size)
          for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.zeros([FLAGS.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample):
              z[idx] = values[kdx]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            save_images(samples, [8, 8], './samples/test_arange_%s.png' % (idx))
        elif OPTION == 2:
          # 100 randomly drawn dimensions (repeats possible, so a gif with
          # the same idx is overwritten). A random base vector is tiled over
          # the batch and the chosen dimension is swept over `values`.
          values = np.arange(0, 1, 1./FLAGS.batch_size)
          for idx in [random.randint(0, 99) for _ in xrange(100)]:
            print(" [*] %d" % idx)
            z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
            z_sample = np.tile(z, (FLAGS.batch_size, 1))
            #z_sample = np.zeros([FLAGS.batch_size, dcgan.z_dim])
            # The loop variable rebinds `z` to each row of z_sample, so the
            # assignment modifies z_sample in place.
            for kdx, z in enumerate(z_sample):
              z[idx] = values[kdx]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            make_gif(samples, './samples/test_gif_%s.gif' % (idx))
        elif OPTION == 3:
          # Same sweep as OPTION 1, but each batch is written as a gif.
          values = np.arange(0, 1, 1./FLAGS.batch_size)
          for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.zeros([FLAGS.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample):
              z[idx] = values[kdx]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            make_gif(samples, './samples/test_gif_%s.gif' % (idx))
        elif OPTION == 4:
          # Same sweep as OPTION 3, additionally keeping every batch so a
          # merged gif can be built afterwards.
          image_set = []
          values = np.arange(0, 1, 1./FLAGS.batch_size)

          for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.zeros([FLAGS.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample): z[idx] = values[kdx]

            image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
            make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))

          # Frames run over batch positions 0..63 and back; this indexes
          # images[0..63], so it assumes a batch of at least 64 -- TODO confirm.
          new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) for idx in range(64) + range(63, -1, -1)]
          make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8)
        elif OPTION == 5:
          # 200 rounds; each round sweeps 5 randomly chosen dimensions of a
          # tiled random base vector.
          image_set = []
          values = np.arange(0, 1, 1./FLAGS.batch_size)
          z_idx = [[random.randint(0,99) for _ in xrange(5)] for _ in xrange(200)]

          for idx in xrange(200):
            print(" [*] %d" % idx)
            #z_sample = np.zeros([FLAGS.batch_size, dcgan.z_dim])
            z = np.random.uniform(-1e-1, 1e-1, size=(dcgan.z_dim))
            z_sample = np.tile(z, (FLAGS.batch_size, 1))

            for kdx, z in enumerate(z_sample):
              for jdx in xrange(5):
                z_sample[kdx][z_idx[idx][jdx]] = values[kdx]

            image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
            make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))

          new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 20]) for idx in range(64) + range(63, -1, -1)]
          make_gif(new_image_set, './samples/test_gif_random_merged.gif', duration=4)
        elif OPTION == 6:
          # 100 rounds; each round sweeps 10 randomly chosen dimensions and
          # saves the batch as a grid image, then a merged gif at the end.
          image_set = []

          values = np.arange(0, 1, 1.0/FLAGS.batch_size).tolist()
          z_idx = [[random.randint(0,99) for _ in xrange(10)] for _ in xrange(100)]

          for idx in xrange(100):
            print(" [*] %d" % idx)
            z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
            z_sample = np.tile(z, (FLAGS.batch_size, 1))

            for kdx, z in enumerate(z_sample):
              for jdx in xrange(10):
                z_sample[kdx][z_idx[idx][jdx]] = values[kdx]

            image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
            save_images(image_set[-1], [8, 8], './samples/test_random_arange_%s.png' % (idx))

          new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) for idx in range(64) + range(63, -1, -1)]
          make_gif(new_image_set, './samples/test_gif_merged_random.gif', duration=4)
        elif OPTION == 7:
          # 50 rounds; each builds 8 random base vectors tiled 8 times
          # (64 rows) and sweeps random dimensions within each group of 8.
          for _ in xrange(50):
            z_idx = [[random.randint(0,99) for _ in xrange(10)] for _ in xrange(8)]

            zs = []
            for idx in xrange(8):
              z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
              zs.append(np.tile(z, (8, 1)))

            z_sample = np.concatenate(zs)
            values = np.arange(0, 1, 1/8.)

            # idx/8 selects the group (integer division under Python 2);
            # only the first 8 of the 10 drawn indices per group are used.
            for idx in xrange(FLAGS.batch_size):
              for jdx in xrange(8):
                z_sample[idx][z_idx[idx/8][jdx]] = values[idx%8]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            save_images(samples, [8, 8], './samples/multiple_testt_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime()))
        elif OPTION == 8:
          # Same z construction as OPTION 7, but every generated sample is
          # written as its own numbered file under ./samples/turing/.
          counter = 0
          for _ in xrange(50):
            import scipy.misc
            z_idx = [[random.randint(0,99) for _ in xrange(10)] for _ in xrange(8)]

            zs = []
            for idx in xrange(8):
              z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim))
              zs.append(np.tile(z, (8, 1)))

            z_sample = np.concatenate(zs)
            values = np.arange(0, 1, 1/8.)

            for idx in xrange(FLAGS.batch_size):
              for jdx in xrange(8):
                z_sample[idx][z_idx[idx/8][jdx]] = values[idx%8]

            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            for sample in samples:
              scipy.misc.imsave('./samples/turing/%s.png' % counter, sample)
              counter += 1
        else:
          # Any other OPTION: build a gif from PNG files already on disk.
          import scipy.misc
          from glob import glob

          samples = []
          fnames = glob("/Users/carpedm20/Downloads/x/1/*.png")
          # Sort by the two integers found after the first and second '_' in
          # the path (the second with its file extension stripped).
          fnames = sorted(fnames, key = lambda x: int(x.split("_")[1]) * 10000 + int(x.split('_')[2].split(".")[0]))
          for f in fnames:
            samples.append(scipy.misc.imread(f))
          make_gif(samples, './samples/training.gif', duration=8, true_image=True)
コード例 #38
0
ファイル: main.py プロジェクト: rrichajalota/DCGAN-tensorflow
def main(_):
  """Entry point: build a DCGAN, train it or restore it, then visualize.

  Prints the parsed flags, fills in missing widths from the matching
  heights, makes sure the checkpoint and sample directories exist, and
  constructs the model inside a TensorFlow session whose GPU memory is
  allowed to grow on demand. With ``FLAGS.is_train`` set the model is
  trained; otherwise a checkpoint is loaded from ``FLAGS.checkpoint_dir``.
  In both cases ``visualize`` is called afterwards with option 1.

  Args:
    _: Unused positional argument.

  Raises:
    Exception: In test mode, when ``dcgan.load`` returns a falsy value.
  """
  pp.pprint(flags.FLAGS.__flags)

  # An unspecified width falls back to the corresponding height.
  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  # Create the output directories on first use.
  for flag_name in ("checkpoint_dir", "sample_dir"):
    target_dir = getattr(FLAGS, flag_name)
    if not os.path.exists(target_dir):
      os.makedirs(target_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth = True

  with tf.Session(config=run_config) as sess:
    # Keyword arguments shared by both dataset variants.
    model_kwargs = dict(
        input_width=FLAGS.input_width,
        input_height=FLAGS.input_height,
        output_width=FLAGS.output_width,
        output_height=FLAGS.output_height,
        batch_size=FLAGS.batch_size,
        sample_num=FLAGS.batch_size,
        dataset_name=FLAGS.dataset,
        input_fname_pattern=FLAGS.input_fname_pattern,
        is_crop=FLAGS.is_crop,
        checkpoint_dir=FLAGS.checkpoint_dir,
        sample_dir=FLAGS.sample_dir)

    if FLAGS.dataset == 'mnist':
      # The mnist branch hard-codes y_dim=10 and c_dim=1.
      model_kwargs.update(y_dim=10, c_dim=1)
    else:
      # Every other dataset takes its channel count from the flags.
      model_kwargs['c_dim'] = FLAGS.c_dim

    dcgan = DCGAN(sess, **model_kwargs)

    show_all_variables()
    if FLAGS.is_train:
      dcgan.train(FLAGS)
    elif not dcgan.load(FLAGS.checkpoint_dir):
      raise Exception("[!] Train a model first, then run test mode")

    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])

    # Visualization runs after either training or loading.
    vis_option = 1
    visualize(sess, dcgan, FLAGS, vis_option)
コード例 #39
0
# Training script: defines the command-line flags, makes sure the output
# directories exist, then builds a DCGAN inside a TensorFlow session and
# trains it.
import os

import numpy as np
import tensorflow as tf

from model import DCGAN
from utils import pp, visualize, to_json

# `os` and `numpy` are imported above because this script uses `os.path`,
# `os.makedirs` and `np.inf` below.

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
# NOTE(review): np.inf is a float used as the default of an integer flag;
# stricter flag implementations may reject it -- confirm against the
# TensorFlow version in use before changing the flag type.
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("image_size", 64, "The size of image to use")
flags.DEFINE_string("dataset", "lfw-aligned-64", "Dataset directory.")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
FLAGS = flags.FLAGS

# Create the checkpoint and sample directories if they are missing.
if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

# Let the session grow its GPU memory allocation on demand.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # Only image size, batch size and checkpoint directory are passed to the
    # constructor; the full FLAGS object is handed to train().
    dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                  is_crop=False, checkpoint_dir=FLAGS.checkpoint_dir)

    dcgan.train(FLAGS)