def _configure_paths():
    """Resolve the run directories on the global FLAGS object.

    Normalizes ``FLAGS.dir_prefix`` ('' when unset, otherwise suffixed with
    '_'). Then points ``FLAGS.checkpoint_dir`` either at a fresh run directory
    named from the experiment settings plus a timestamp (when
    ``FLAGS.checkpoint`` is None) or at the named existing checkpoint.
    Finally sets ``FLAGS.sample_dir`` to that directory's ``samples/`` child.
    """
    FLAGS.dir_prefix = '' if FLAGS.dir_prefix is None else FLAGS.dir_prefix + '_'
    if FLAGS.checkpoint is None:
        # New run: <prefix><algorithm>_<alpha>_<disc_type>_<YYYYmmdd-HHMMSS>
        run_name = FLAGS.dir_prefix + "_".join([
            FLAGS.algorithm,
            str(FLAGS.alpha),
            FLAGS.disc_type,
            datetime.now().strftime("%Y%m%d-%H%M%S"),
        ])
    else:
        # Resume / evaluate an existing run directory by name.
        run_name = FLAGS.checkpoint
    FLAGS.checkpoint_dir = os.path.join(FLAGS.checkpoint_dir, run_name)
    FLAGS.sample_dir = os.path.join(FLAGS.checkpoint_dir, 'samples/')


def _configure_image_size():
    """Set the 28x28 image geometry used for MNIST on FLAGS.

    Heights are always forced to 28. Widths default to 28 only when they
    were not supplied.
    """
    mnist_size = 28
    FLAGS.input_height = mnist_size
    FLAGS.output_height = mnist_size
    if FLAGS.input_width is None:
        FLAGS.input_width = mnist_size
    if FLAGS.output_width is None:
        FLAGS.output_width = mnist_size


def _prepare_run_dirs():
    """Create the checkpoint and sample directories, then dump the scripts.

    The listed source files and ``FLAGS.script_file`` are handed to
    ``utils.dump_script`` together with the checkpoint directory. NOTE(review):
    presumably this snapshots the code into the run directory; confirm in utils.
    """
    for path in (FLAGS.checkpoint_dir, FLAGS.sample_dir):
        if not os.path.exists(path):
            os.makedirs(path)
    file_list = ['main.py', 'model.py', 'utils.py', 'ops.py', 'sn.py']
    utils.dump_script(FLAGS.checkpoint_dir, FLAGS.script_file, file_list=file_list)


def _build_dcgan(sess):
    """Construct the MNIST DCGAN bound to ``sess`` from the current FLAGS."""
    return DCGAN(
        sess,
        input_width=FLAGS.input_width,
        input_height=FLAGS.input_height,
        output_width=FLAGS.output_width,
        output_height=FLAGS.output_height,
        batch_size=FLAGS.batch_size,
        sample_num=FLAGS.batch_size,
        y_dim=10,  # ten MNIST digit classes
        z_dim=FLAGS.z_dim,
        dataset_name=FLAGS.dataset,
        crop=FLAGS.crop,
        checkpoint_dir=FLAGS.checkpoint_dir,
        data_dir=FLAGS.data_dir,
        algorithm=FLAGS.algorithm,
        estimate_confuse=FLAGS.estimate_confuse,
        perm_regularizer=FLAGS.perm_regularizer,
        alpha=FLAGS.alpha,
        disc_type=FLAGS.disc_type,
        add_noise=FLAGS.add_noise,
        noise_alpha=FLAGS.noise_alpha,
        config=FLAGS)


def main(_):
    """Configure FLAGS, build the MNIST DCGAN, then train and/or evaluate it.

    The single positional argument is unused (presumably the argv list passed
    by the app runner).

    With ``FLAGS.train`` set, the model is trained. Otherwise the checkpoint
    in ``FLAGS.checkpoint_dir`` is loaded, and the model is trained first if
    no checkpoint could be loaded. ``dcgan.recover_labels`` runs afterwards.
    """
    FLAGS.concat_y_layers = [int(x) for x in FLAGS.concat_y_layers]
    _configure_paths()
    _configure_image_size()
    _prepare_run_dirs()

    if FLAGS.logs_at_ckpt:
        FLAGS.logs_dir = FLAGS.checkpoint_dir

    # Only MNIST is wired up here (28x28 images, y_dim=10), so any dataset
    # given on the command line is overridden. This also means the model is
    # always constructed; there is no unsupported-dataset path.
    FLAGS.dataset = 'mnist'

    # Log the configuration only after every derived value above is filled in,
    # so the printout matches what actually runs.
    pp.pprint(flags.FLAGS.__flags)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True
    with tf.Session(config=run_config) as sess:
        dcgan = _build_dcgan(sess)
        show_all_variables()

        if FLAGS.train:
            dcgan.train(FLAGS)
        elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
            # Nothing to load: fall back to training before evaluation.
            print("[!] Training a model first, then run test mode")
            dcgan.train(FLAGS)

        # NOTE(review): reconstructed to run after both the train and the test
        # paths; confirm against the intended original indentation.
        dcgan.recover_labels(FLAGS)