示例#1
0
def _ensure_dir(path):
  """Create ``path`` and any missing parents; do nothing if it is already a directory.

  Unlike an ``os.path.exists`` check followed by ``os.makedirs``, this does not
  fail if another process creates the directory in between. Any other
  ``OSError`` (e.g. missing permissions, or a regular file already at ``path``)
  is re-raised.
  """
  try:
    os.makedirs(path)
  except OSError:
    # Only swallow the error when the directory now exists.
    if not os.path.isdir(path):
      raise


def _resolve_checkpoint_dir():
  """Return the run directory to use under ``FLAGS.checkpoint_dir``.

  If ``FLAGS.checkpoint`` is set, that existing subdirectory is reused.
  Otherwise a new name of the form
  ``<dir_prefix><algorithm>_<alpha>_<disc_type>_<YYYYmmdd-HHMMSS>`` is built.
  ``FLAGS.dir_prefix`` must already be a string (empty, or ending in '_').
  """
  if FLAGS.checkpoint is not None:
    return os.path.join(FLAGS.checkpoint_dir, FLAGS.checkpoint)
  run_name = FLAGS.dir_prefix + "_".join([
      FLAGS.algorithm,
      str(FLAGS.alpha),
      FLAGS.disc_type,
      datetime.now().strftime("%Y%m%d-%H%M%S"),
  ])
  return os.path.join(FLAGS.checkpoint_dir, run_name)


def main(_):
  """Normalise flags, prepare output directories, then train or evaluate a DCGAN on MNIST.

  Side effects: mutates several ``FLAGS`` attributes (``concat_y_layers``,
  ``dir_prefix``, ``checkpoint_dir``, ``sample_dir``, ``dataset``, image
  sizes, possibly ``logs_dir``), creates the checkpoint/sample directories,
  dumps the listed source files via ``utils.dump_script``, and opens a
  TensorFlow session.

  If ``FLAGS.train`` is false, this tries to load a checkpoint from
  ``FLAGS.checkpoint_dir`` and falls back to training when loading fails.
  Without ``FLAGS.checkpoint`` that directory is a fresh timestamped one, so
  loading fails and training runs. In both modes ``dcgan.recover_labels``
  runs at the end.

  Args:
    _: Unused positional argument (presumably argv from the app runner; confirm).
  """
  FLAGS.concat_y_layers = [int(x) for x in FLAGS.concat_y_layers]

  # A non-empty prefix is joined to the generated run name with '_'.
  FLAGS.dir_prefix = '' if FLAGS.dir_prefix is None else FLAGS.dir_prefix + '_'

  FLAGS.checkpoint_dir = _resolve_checkpoint_dir()
  FLAGS.sample_dir = os.path.join(FLAGS.checkpoint_dir, 'samples/')

  # Only MNIST (28x28 images, y_dim=10) is supported, so these values override
  # the command line. They are applied *before* the flags are printed so that the
  # printed configuration matches the one actually used.
  FLAGS.dataset = 'mnist'
  FLAGS.input_height = 28
  FLAGS.output_height = 28
  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  pp.pprint(flags.FLAGS.__flags)

  _ensure_dir(FLAGS.checkpoint_dir)
  _ensure_dir(FLAGS.sample_dir)

  # Snapshot the sources next to the checkpoints so the run can be reproduced.
  file_list = ['main.py', 'model.py', 'utils.py', 'ops.py', 'sn.py']
  utils.dump_script(FLAGS.checkpoint_dir, FLAGS.script_file, file_list=file_list)

  if FLAGS.logs_at_ckpt:
    FLAGS.logs_dir = FLAGS.checkpoint_dir

  run_config = tf.ConfigProto()
  # Grab GPU memory on demand instead of reserving it all up front.
  run_config.gpu_options.allow_growth = True
  with tf.Session(config=run_config) as sess:
    dcgan = DCGAN(
        sess,
        input_width=FLAGS.input_width,
        input_height=FLAGS.input_height,
        output_width=FLAGS.output_width,
        output_height=FLAGS.output_height,
        batch_size=FLAGS.batch_size,
        sample_num=FLAGS.batch_size,
        y_dim=10,
        z_dim=FLAGS.z_dim,
        dataset_name=FLAGS.dataset,
        crop=FLAGS.crop,
        checkpoint_dir=FLAGS.checkpoint_dir,
        data_dir=FLAGS.data_dir,
        algorithm=FLAGS.algorithm,
        estimate_confuse=FLAGS.estimate_confuse,
        perm_regularizer=FLAGS.perm_regularizer,
        alpha=FLAGS.alpha,
        disc_type=FLAGS.disc_type,
        add_noise=FLAGS.add_noise,
        noise_alpha=FLAGS.noise_alpha,
        config=FLAGS)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)
    elif not dcgan.load(FLAGS.checkpoint_dir)[0]:
      # No usable checkpoint: train first so there is a model to evaluate.
      print("[!] Training a model first, then run test mode")
      dcgan.train(FLAGS)

    dcgan.recover_labels(FLAGS)