Example #1
def train():
    # Placeholder toggling train/test behavior (e.g. batch norm) in the
    # visualization graph.
    is_training_v = tf.placeholder(tf.bool, shape=[])
    phv = {'is_training_v': is_training_v}
    with tf.Session() as sess:
        # V_graph, Genc, Gdec, LOG_PATH, OUT_PATH and train_iters are defined
        # elsewhere in the module.
        fa_images = V_graph(sess, phv)
        # Restore only the encoder/decoder weights from the trained checkpoint.
        saver = tf.train.Saver(Genc.vars + Gdec.vars)
        init = tf.global_variables_initializer()
        sess.run(init)
        saver.restore(sess, LOG_PATH + '/weight.ckpt')

        #####################  generate data  #####################
        v_dict = {phv['is_training_v']: True}
        cnt = 0
        for i in range(train_iters):
            # Generate a batch of fake images and write each one to disk.
            sample = sess.run(fa_images, feed_dict=v_dict)
            sample = util.to_uint8(sample)
            for s in sample:
                cnt += 1
                io.imsave('%s/%06d.jpg' % (OUT_PATH, cnt), s, quality=95)
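Both this example and Example #2 call a to_uint8 helper that is not shown here. The sketch below is a hypothetical stand-in, assuming float images in [-1, 1] when normalize=True and values already in [0, 255] otherwise; the real util.to_uint8 may use a different convention.

import numpy as np

def to_uint8(images, normalize=True):
    """Hypothetical stand-in for the to_uint8 helper used in these examples.

    Assumes float images in [-1, 1] when normalize=True; otherwise assumes
    the values are already in [0, 255] and only need clipping and casting.
    """
    images = np.asarray(images, dtype=np.float32)
    if normalize:
        images = (images + 1.0) * 127.5  # map [-1, 1] -> [0, 255]
    return np.clip(images, 0, 255).astype(np.uint8)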
Example #2
def saveTextures(outData, textures, out_path, loadFromDirectory):
    """
    Save the textures to the specified path.
    Inputs:
        outData - height x width x depth array for the output texture(s)
        textures - list of InputTextureFile objects
        out_path - path to output
        loadFromDirectory - is the output path a directory?
    Output:
        Saves textures.
    """
    assert len(outData.shape) == 3

    out_dir = os.path.dirname(out_path)
    if (out_dir != "" and not os.path.exists(out_dir)):
        os.makedirs(out_dir)

    current_depth = 0
    for textureFile in textures:
        next_depth = current_depth + textureFile.depth

        # Slice out this texture's channels; drop the channel axis for
        # single-channel textures.
        textureData = outData[:, :, current_depth:next_depth]
        if (textureData.shape[2] < 2):
            textureData = numpy.squeeze(textureData, axis=2)

        # Integer textures are converted back to uint8 before saving.
        if (not textureFile.isFloat):
            textureData = to_uint8(textureData, normalize=False)

        # Save the solved texture
        if (loadFromDirectory):
            base, ext = os.path.splitext(textureFile.name)
            out_texture = os.path.join(out_path, base + "-erased" + ext)
        else:
            out_texture = out_path

        if (textureFile.isDataFile):
            weight_data.write_tex_to_path(out_texture, textureData)
        else:
            texture.save_texture(textureData, out_texture)

        current_depth = next_depth
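A usage sketch for saveTextures: the FakeTextureFile class below is hypothetical and only mirrors the attributes the function actually reads (name, depth, isFloat, isDataFile); the real InputTextureFile type is defined elsewhere.

import numpy

# Hypothetical stand-in exposing only the attributes saveTextures reads.
class FakeTextureFile(object):
    def __init__(self, name, depth, isFloat=False, isDataFile=False):
        self.name = name
        self.depth = depth
        self.isFloat = isFloat
        self.isDataFile = isDataFile

# Two textures packed along the channel axis: a 3-channel color map and a
# 1-channel mask, both 64x64.
outData = numpy.zeros((64, 64, 4), dtype=numpy.float32)
textures = [FakeTextureFile("color.png", 3), FakeTextureFile("mask.png", 1)]

# Directory mode: each texture is written as <name>-erased<ext> under out/.
saveTextures(outData, textures, "out/", loadFromDirectory=True)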
Example #3
def train():
  # Learning-rate and training-mode placeholders, grouped per sub-graph:
  # generator (g), discriminator (d), visualization (v) and adversarial
  # classifier (a).
  lr_g = tf.placeholder(tf.float32, shape=[])
  lr_d = tf.placeholder(tf.float32, shape=[])
  is_training_g = tf.placeholder(tf.bool, shape=[])
  is_training_d = tf.placeholder(tf.bool, shape=[])
  is_training_v = tf.placeholder(tf.bool, shape=[])
  is_training_a = tf.placeholder(tf.bool, shape=[])
  phg = {'lr_g': lr_g,
         'is_training_g': is_training_g}
  phd = {'lr_d': lr_d,
         'is_training_d': is_training_d}
  phv = {'is_training_v': is_training_v}
  pha = {'is_training_a': is_training_a}
  with tf.Session() as sess:
    # G_graph, G_adv_graph, D_graph, V_graph, Adv_graph, Genc and Gdec are
    # defined elsewhere in the module.
    G_u, G_loss, G_update = G_graph(sess, phg)
    Ga_u, Ga_loss, Ga_acc, Ga_update = G_adv_graph(sess, phg)
    D_loss, D_update = D_graph(sess, phd)
    tr_images, re_images, fa_images = V_graph(sess, phv)
    adv_loss, adv_acc, adv_update, adv_pred = Adv_graph(sess, pha)
    # Only the encoder/decoder weights are checkpointed.
    saver = tf.train.Saver(Genc.vars + Gdec.vars)
    init = tf.global_variables_initializer()
    sess.run(init)
    # saver.restore(sess, LOG_PATH + '/weight.ckpt')

    #####################  pre-train Generator & Discriminator  #####################
    g_dict = {phg['lr_g']: 1e-4,
              phg['is_training_g']: True}
    d_dict = {phd['lr_d']: 1e-4,
              phd['is_training_d']: True}
    for epoch in range(N_EPOCH):
      for i in range(train_iters):
        # Update the discriminator on five out of every six iterations and
        # the generator on the sixth.
        if i % 6 != 0:
          d_loss, _ = sess.run([D_loss, D_update], feed_dict=d_dict)
        else:
          g_loss, _ = sess.run([G_loss, G_update], feed_dict=g_dict)

        if i % 100 == 0 and i != 0:
          log_string('====pre_epoch_%d====iter_%d: g_loss=%.4f, d_loss=%.4f' % (epoch, i, g_loss, d_loss))
    saver.save(sess, LOG_PATH + '/pretrain.ckpt')

    #####################  train and fix adversarial classifier  #####################
    a_dict = {pha['is_training_a']: True}
    for epoch in range(N_ADV):
      for i in range(train_iters):
        a_loss, a_acc, _ = sess.run([adv_loss, adv_acc, adv_update], feed_dict=a_dict)
        if i % 100 == 0 and i != 0:
          log_string('====adv_epoch_%d====iter_%d: a_loss=%.4f, a_acc=%.4f' % (epoch, i, a_loss, a_acc))

    #####################  fine-tune Generator & Discriminator  #####################
    g_dict = {phg['lr_g']: 5e-6,
              phg['is_training_g']: True}
    d_dict = {phd['lr_d']: 5e-6,
              phd['is_training_d']: True}
    v_dict = {phv['is_training_v']: True}
    for epoch in range(N_EPOCH):
      for i in range(train_iters):
        if i % 6 != 0:
          d_loss, _ = sess.run([D_loss, D_update], feed_dict=d_dict)
        else:
          # The generator update now also uses the adversarial classifier.
          g_loss, a_acc, _ = sess.run([Ga_loss, Ga_acc, Ga_update], feed_dict=g_dict)

        if i % 100 == 0 and i != 0:
          log_string('====epoch_%d====iter_%d: g_loss=%.4f, d_loss=%.4f, a_acc=%.4f' % (epoch, i, g_loss, d_loss, a_acc))
      # Tile true / reconstructed / fake images into one montage per epoch:
      # (3, B, H, W, C) -> (B, H, 3, W, C) -> (B*H, 3*W, C).
      a, b, c = sess.run([tr_images, re_images, fa_images], feed_dict=v_dict)
      image_list = [a, b, c]
      sample = np.transpose(image_list, (1, 2, 0, 3, 4))
      sample = np.reshape(sample, (-1, sample.shape[2] * sample.shape[3], sample.shape[4]))
      sample = util.to_uint8(sample)
      io.imsave('%s/Epoch-%d.jpg' % (LOG_PATH, epoch), sample, quality=95)
      saver.save(sess, LOG_PATH + '/weight.ckpt')
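The epoch-end montage above relies on a transpose/reshape trick to tile the three image sets side by side. A small self-contained sketch with dummy arrays (shapes are illustrative, not taken from the training code):

import numpy as np

# Three dummy "batches" standing in for tr_images, re_images and fa_images:
# a batch of 2 RGB images, 4x5 pixels each.
a = np.zeros((2, 4, 5, 3))
b = np.ones((2, 4, 5, 3))
c = np.full((2, 4, 5, 3), 2.0)

# (3, B, H, W, C) -> (B, H, 3, W, C): group each sample's three versions.
grid = np.transpose([a, b, c], (1, 2, 0, 3, 4))
# (B, H, 3, W, C) -> (B*H, 3*W, C): samples stack vertically, the three
# versions of each sample sit side by side horizontally.
grid = np.reshape(grid, (-1, grid.shape[2] * grid.shape[3], grid.shape[4]))
print(grid.shape)  # (8, 15, 3)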