def main():
    """Train a GAN on a folder of face images and save sample grids.

    Seeds the TensorFlow and NumPy random generators, builds the dataset
    from the globbed JPEG paths with ``make_anime_dataset``, constructs the
    ``Generator`` and ``Discriminator`` models, and then alternates one
    discriminator update and one generator update per loop iteration.
    Every 100 iterations the current losses are printed and 100 generated
    images are written to ``images/gan-<iteration>.png``.
    """
    tf.random.set_seed(22)
    np.random.seed(22)
    # NOTE(review): TensorFlow is already imported at this point, so this
    # log-level setting may come too late to take effect — confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    # hyper parameters
    z_dim = 100
    # Each "epoch" below is a single batch step, not a full dataset pass.
    epochs = 3000000
    batch_size = 512
    learning_rate = 0.002
    is_training = True

    img_paths = glob.glob('/Users/tongli/Desktop/Python/TensorFlow/faces/*.jpg')

    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size)
    print(dataset, img_shape)
    # Peek at one batch to report its shape and value range.
    sample = next(iter(dataset))
    print(sample.shape,
          tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # Repeat indefinitely so next() never exhausts the iterator.
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # Make sure the output directory exists before the first image is saved.
    output_dir = 'images'
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(epochs):

        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # train D
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x,
                               is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(
            zip(grads, discriminator.trainable_variables))

        # train G
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))

            # Draw preview latents from the same [-1, 1) range used for
            # training so the samples reflect what the generator was fed.
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            fake_image = generator(z, training=False)
            out_path = os.path.join(output_dir, 'gan-%d.png' % epoch)
            # The 10 is presumably the grid width — verify in save_result.
            save_result(fake_image.numpy(), 10, out_path, color_mode='P')
# Example #2
# 0
def main():
    """Train a GAN on a folder of face images and save sample grids.

    Seeds the TensorFlow and NumPy random generators, builds the dataset
    from the globbed JPEG paths with ``make_anime_dataset``, constructs the
    ``Generator`` and ``Discriminator`` models, and then alternates one
    discriminator update and one generator update per loop iteration.
    Every 100 iterations the current losses are printed and 100 generated
    images are written to ``./gan_images/gan-<iteration>.png``.
    """
    tf.random.set_seed(22)
    np.random.seed(22)
    # NOTE(review): TensorFlow is already imported at this point, so this
    # log-level setting may come too late to take effect — confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    # Hyper parameters.
    z_dim = 10
    # Each "epoch" below is a single batch step, not a full dataset pass.
    # Kept separate from the loop variable so the count is not overwritten.
    epochs = 3000000
    batch_size = 512
    learning_rate = 0.002
    is_training = True

    # Load the dataset: collect the path of every image.
    img_paths = glob.glob(r'D:\PyCharm Projects\CarNum-CNN\data\faces\*.jpg')

    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size)

    # Sample indefinitely so next() never exhausts the iterator.
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    # Build the generator and discriminator models.
    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    # One optimizer per network.
    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # Make sure the output directory exists before the first image is saved.
    output_dir = './gan_images'
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(epochs):

        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # train D
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x,
                               is_training)

        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(
            zip(grads, discriminator.trainable_variables))

        # train G
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)

        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Report progress and save a preview.
        if epoch % 100 == 0:
            print(epoch, "d_loss:", float(d_loss), "g_loss:", float(g_loss))

            # Draw preview latents from the same [-1, 1) range used for
            # training so the samples reflect what the generator was fed.
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            fake_image = generator(z, training=False)
            out_path = os.path.join(output_dir, 'gan-%d.png' % epoch)
            # The 10 is presumably the grid width — verify in save_result.
            save_result(fake_image.numpy(), 10, out_path, color_mode='P')