Example #1
0
def main(_):
    """Build a run-specific log directory, train or restore a DCGAN, and dump the flags.

    Args:
        _: unused positional argument supplied by the flags/app runner.
    """
    # Default to square inputs/outputs when only a height is given.
    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    # Encode timestamp and hyper-parameters into the run's directory name
    # so distinct configurations never collide on disk.
    file_name = time.strftime("%Y_%m_%d_%H%M", time.localtime())
    if FLAGS.unreg:
        file_name += "_unreg_dcgan"
    else:
        file_name += "_regularized_dcgan_" + str(FLAGS.gamma) + "gamma"
    if FLAGS.annealing:
        file_name += "_annealing_" + str(FLAGS.decay_factor) + "decayfactor"
    if FLAGS.rmsprop:
        file_name += "_rmsprop"
    else:
        file_name += "_adam"
    file_name += "_" + str(FLAGS.disc_update_steps) + "dsteps"
    file_name += "_" + str(FLAGS.disc_learning_rate) + "dlnr"
    file_name += "_" + str(FLAGS.gen_learning_rate) + "glnr"
    file_name += "_" + str(FLAGS.epochs) + "epochs"
    file_name += "_" + str(FLAGS.dataset)

    log_dir = os.path.abspath(os.path.join(FLAGS.root_dir, file_name))

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)

    pp.pprint(flags.FLAGS.__flags)

    with tf.Session() as sess:

        dcgan = DCGAN(sess, log_dir, FLAGS)

        show_all_variables()

        # BUG FIX: the original used `is not "None"` — an identity comparison
        # against a string literal, whose result is implementation-defined
        # (and a SyntaxWarning on modern CPython). Use value equality.
        # NOTE(review): the flag sentinel appears to be the *string* "None",
        # not the None object — confirm against the flag's definition.
        if FLAGS.checkpoint_dir != "None":
            if not dcgan.load_ckpt()[0]:
                raise Exception("[!] ERROR: provide valid checkpoint_dir")
        else:
            starttime = time.time()

            dcgan.train(FLAGS)

            endtime = time.time()
            print('Total Train Time: {:.2f}'.format(endtime - starttime))

        dcgan.generate(FLAGS, option=1)

    # Persist the flag values next to the run's artifacts. The context
    # manager guarantees the handle is closed even if json.dump raises.
    # Append mode ('a') is kept from the original to preserve behavior.
    with open(os.path.join(log_dir, "flags.json"), 'a') as flag_file:
        json.dump(vars(FLAGS), flag_file)
Example #2
0
File: main.py  Project: yunfanz/VFG
def main(_):
    """Set up output/checkpoint directories, then train a DCGAN or generate samples.

    Args:
        _: unused positional argument supplied by the flags/app runner.
    """
    STARTED_DATESTRING = "{0:%Y-%m-%dT%H:%M:%S}".format(
        datetime.now()).replace(":", "-")
    pp.pprint(flags.FLAGS.__flags)

    FLAGS.mode = FLAGS.mode.lower()
    # BUG FIX: the original validated with `assert`, which is silently
    # stripped under `python -O`; raise an explicit error instead.
    if FLAGS.mode not in ('train', 'generate'):
        raise ValueError("mode must be 'train' or 'generate'!")

    if FLAGS.mode == 'train':
        if FLAGS.out_dir is None:
            FLAGS.out_dir = 'out/train_' + STARTED_DATESTRING
            print('Using default out_dir {0}'.format(FLAGS.out_dir))
        elif FLAGS.out_dir.endswith('/'):
            # Strip a trailing slash so path concatenation below stays clean.
            FLAGS.out_dir = FLAGS.out_dir[:-1]
        if FLAGS.checkpoint_dir is None:
            FLAGS.checkpoint_dir = FLAGS.out_dir + '/checkpoint'
    else:
        # Generation requires an explicit checkpoint to restore from.
        if FLAGS.checkpoint_dir is None:
            raise Exception(
                'Cannot generate: checkpoint {0} does not exist!'.format(
                    FLAGS.checkpoint_dir))
        elif FLAGS.checkpoint_dir.endswith('/'):
            FLAGS.checkpoint_dir = FLAGS.checkpoint_dir[:-1]
        if FLAGS.out_dir is None:
            FLAGS.out_dir = 'out/gene_' + STARTED_DATESTRING

    # BUG FIX: the original created the samples/checkpoint/logs
    # subdirectories only when out_dir itself was newly created, so a
    # pre-existing out_dir missing those subdirectories made later writes
    # fail. exist_ok=True makes creation idempotent in both cases.
    os.makedirs(FLAGS.out_dir, exist_ok=True)
    if FLAGS.mode == 'train':
        for sub in ('samples', 'checkpoint', 'logs'):
            os.makedirs(os.path.join(FLAGS.out_dir, sub), exist_ok=True)

    if FLAGS.audio_params is None:
        if FLAGS.mode == 'train':
            # Use the repo-local defaults and archive a copy beside the
            # checkpoints so the run is reproducible.
            FLAGS.audio_params = './audio_params.json'
            copyfile(FLAGS.audio_params,
                     FLAGS.checkpoint_dir + '/audio_params.json')
        else:
            # When generating, reuse the params archived with the checkpoint.
            print('Using json file from {0}'.format(FLAGS.checkpoint_dir))
            FLAGS.audio_params = FLAGS.checkpoint_dir + '/audio_params.json'

    with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
        if FLAGS.dataset == 'wav':
            # Hyper-parameters for the wav dataset live in the JSON file and
            # override the command-line flags.
            with open('audio_params.json', 'r') as f:
                audio_params = json.load(f)
            FLAGS.epoch = audio_params['epoch']
            FLAGS.learning_rate = audio_params['learning_rate']
            FLAGS.beta1 = audio_params['beta1']
            FLAGS.sample_length = audio_params['sample_length']
            dcgan = DCGAN(sess,
                          batch_size=FLAGS.batch_size,
                          z_dim=audio_params['z_dim'],
                          sample_length=FLAGS.sample_length,
                          c_dim=1,
                          dataset_name=FLAGS.dataset,
                          audio_params=FLAGS.audio_params,
                          data_dir=FLAGS.data_dir,
                          use_disc=FLAGS.use_disc,
                          use_fourier=FLAGS.use_fourier,
                          run_g=FLAGS.run_g,
                          checkpoint_dir=FLAGS.checkpoint_dir,
                          out_dir=FLAGS.out_dir,
                          mode=FLAGS.mode)
        else:
            raise Exception('dataset not understood')

        if FLAGS.mode == 'train':
            dcgan.train(FLAGS)
        else:
            print('Generating {0} batches of size {1} from checkpoint {2}'.
                  format(FLAGS.gen_size, FLAGS.batch_size,
                         FLAGS.checkpoint_dir))
            dcgan.load(FLAGS.checkpoint_dir)
            dcgan.generate(FLAGS)

        if FLAGS.visualize:
            # Export generator layer weights for the web-based viewer.
            to_json("./web/js/layers.js",
                    [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
                    [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
                    [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
                    [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
                    [dcgan.h4_w, dcgan.h4_b, None])

            # OPTION selects one of the visualize() sampling strategies.
            OPTION = 2
            visualize(sess, dcgan, FLAGS, OPTION)