def main(_):
    """Build the DCGAN model and either train it or run its predictor.

    Prints the parsed flags, makes sure the checkpoint and sample
    directories exist, opens a TensorFlow session and constructs a ``DCGAN``
    from ``FLAGS`` plus a set of hard-coded hyper-parameters. When
    ``FLAGS.is_train`` is set the model is trained; otherwise
    ``dcgan.predictor`` is called.

    Args:
        _: Ignored. Presumably the argv list handed over by the flags
            runner (e.g. ``tf.app.run``) -- TODO confirm against the caller.

    Raises:
        AssertionError: If ``FLAGS.dataset`` is ``'mnist'``, which this
            entry point rejects.
    """
    pp.pprint(flags.FLAGS.__flags)

    # Create the output directories on first use.
    for directory in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    if FLAGS.dataset == 'mnist':
        # Raised explicitly instead of `assert False` so the check still
        # fires under `python -O`, and before a session is allocated.
        # AssertionError is kept so the failure type stays the same.
        raise AssertionError(
            "dataset 'mnist' is not supported by this entry point")

    # Soft placement lets TensorFlow fall back to another device when an op
    # cannot be placed on the requested one; placement logging stays off.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    with tf.Session(config=session_config) as sess:
        dcgan = DCGAN(
            sess,
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            sample_size=16,
            z_dim=8192,
            d_label_smooth=.25,
            generator_target_prob=.75 / 2.,
            out_stddev=.075,
            out_init_b=-.45,
            # FLAGS.image_width is used for both of the first two dimensions
            # (presumably square RGB images -- confirm).
            image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
            dataset_name=FLAGS.dataset,
            is_crop=FLAGS.is_crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
            generator=Generator(),
            train_func=train,
            discriminator_func=discriminator,
            # Locally added hook: the set of prediction functions.
            predictor_func=predictor,
            build_model_func=build_model,
            config=FLAGS,
            # "gpu:4" is currently disabled.
            devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"],
        )

        if FLAGS.is_train:
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            # Previously this branch called dcgan.load(FLAGS.checkpoint_dir);
            # it now runs the predictor instead.
            dcgan.predictor(FLAGS)

        # NOTE(review): OPTION is not read anywhere in the visible part of
        # this function; kept as-is in case code beyond this view relies on
        # it -- confirm and remove if unused.
        OPTION = 2