Exemple #1
0
def main():
    """Train the generator/discriminator pair and save sample image grids.

    Seeds TensorFlow and NumPy, builds the dataset from ``faces/*.jpg``
    (resolved against the current working directory) through
    ``make_anime_dataset``, then runs ``epochs`` rounds.  Each round performs
    ``k`` discriminator updates followed by one generator update.  On every
    100th round (round 0 included) 100 freshly sampled latent vectors are
    pushed through the generator and handed to ``save_result`` with the
    target ``gan-images/wgan-<round>.png`` next to this file.

    Raises:
        AssertionError: if TensorFlow is not a 2.x release or no image
            matches ``faces/*.jpg``.
    """
    tf.random.set_seed(233)
    np.random.seed(233)
    assert tf.__version__.startswith('2.'), 'TensorFlow 2.x is required'

    # hyper parameters
    z_dim = 100  # length of the latent vector fed to the generator
    epochs = 1  # number of training rounds (see the loop below)
    batch_size = 64
    learning_rate = 0.0005
    is_training = True
    k = 5  # discriminator updates per generator update

    root = os.path.dirname(os.path.abspath(__file__))
    save_path = os.path.join(root, 'gan-images')
    # Make sure the output directory exists before anything is written to it.
    os.makedirs(save_path, exist_ok=True)

    img_path = glob.glob('faces/*.jpg')
    assert len(img_path) > 0, "no images matched 'faces/*.jpg'"

    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    # Peek at one batch to report its shape and value range.
    sample = next(iter(dataset))
    print(sample.shape,
          tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # Repeat endlessly so next(db_iter) never runs out during training.
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    # NOTE(review): z_sample is never read below.  It is left in place
    # because removing a random op would presumably shift the seeded random
    # stream used by the rest of the run -- confirm before deleting.
    z_sample = tf.random.normal([100, z_dim])
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                           beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                           beta_1=0.5)

    # One "epoch" here is k discriminator steps plus one generator step,
    # not a full pass over the dataset.
    for epoch in range(epochs):
        time_start = datetime.datetime.now()
        for _ in range(k):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)

            # train D
            with tf.GradientTape() as tape:
                d_loss, gp = d_loss_fn(generator, discriminator, batch_z,
                                       batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(
                zip(grads, discriminator.trainable_variables))

        # train G on a fresh latent batch
        batch_z = tf.random.normal([batch_size, z_dim])

        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Reports the losses of the last discriminator step and of the
        # generator step, e.g.
        # Epoch: 0/1 TimeUsed: 0:00:03.809766 d-loss: 0.11730888 g-loss: -0.40536720 gp: 0.01527955
        print(
            f"Epoch: {epoch}/{epochs} TimeUsed: {datetime.datetime.now()-time_start} d-loss: {d_loss:.8f} g-loss: {g_loss:.8f} gp: {gp:.8f}"
        )

        if epoch % 100 == 0:
            # Periodic snapshot: generate 100 images in inference mode.
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            # Separate name so the list of input files in img_path is not
            # overwritten by the output file path.
            snapshot_path = os.path.join(save_path, 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, snapshot_path, color_mode='P')
Exemple #2
0
def main():
    """Train the generator/discriminator pair with checkpointing.

    Seeds TensorFlow and NumPy, builds the dataset from ``faces/*.jpg``
    (resolved against the current working directory) through
    ``make_anime_dataset``, restores generator/discriminator weights when a
    checkpoint index file is found under ``models/`` next to this file, then
    runs ``epochs`` rounds.  Each round performs ``k`` discriminator updates
    followed by one generator update.  On every 100th round (round 0
    included) a grid of generated images is handed to ``save_result`` with
    the target ``gan_images/gan-<round>.png``, the current losses are
    recorded, and both models' weights are saved.

    Raises:
        AssertionError: if TensorFlow is not a 2.x release.
        FileNotFoundError: if no image matches ``faces/*.jpg``.
    """
    tf.random.set_seed(3333)
    np.random.seed(3333)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.'), 'TensorFlow 2.x is required'
    z_dim = 100  # length of the latent vector z
    epochs = 1  # number of training rounds
    batch_size = 64  # batch size
    learning_rate = 0.0002
    is_training = True
    k = 5  # discriminator updates per generator update

    # Paths
    root = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(root, 'models')
    generator_ckpt_path = os.path.join(model_path, 'generator',
                                       'generator.ckpt')
    discriminator_ckpt_path = os.path.join(model_path, 'discriminator',
                                           'discriminator.ckpt')
    save_image_path = os.path.join(root, 'gan_images')
    # Create every output directory up front so the periodic image and
    # checkpoint writes cannot fail on a missing folder.
    for out_dir in (os.path.dirname(generator_ckpt_path),
                    os.path.dirname(discriminator_ckpt_path),
                    save_image_path):
        os.makedirs(out_dir, exist_ok=True)

    # Collect the dataset file paths
    img_path = glob.glob('faces/*.jpg')
    print('images num:', len(img_path))
    if not img_path:
        # Fail here with a clear message instead of later, when the first
        # batch is drawn from an empty dataset.
        raise FileNotFoundError(
            "no images matched 'faces/*.jpg' in the current directory")

    # Build the dataset object
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)
    print(dataset, img_shape)

    sample = next(iter(dataset))  # peek at one batch
    print(
        f"batch_shape: {sample.shape} max: {tf.reduce_max(sample).numpy()} min: {tf.reduce_min(sample).numpy()}"
    )

    # Repeat endlessly: a finite repeat count would make next(db_iter) raise
    # StopIteration once epochs * k exceeds the number of available batches.
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()  # create the generator
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()  # create the discriminator
    discriminator.build(input_shape=(None, 64, 64, 3))

    # Separate optimizers for the generator and the discriminator
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate,
                                        beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate,
                                        beta_1=0.5)

    # Resume from earlier checkpoints when their index files are present.
    if os.path.exists(generator_ckpt_path + '.index'):
        generator.load_weights(generator_ckpt_path)
        print('Loaded generator ckpt!!')
    if os.path.exists(discriminator_ckpt_path + '.index'):
        discriminator.load_weights(discriminator_ckpt_path)
        print('Loaded discriminator ckpt!!')

    # Loss history, appended to only on snapshot rounds; not read anywhere
    # inside the visible part of this function.
    d_losses, g_losses = [], []
    for epoch in range(epochs):  # run `epochs` training rounds
        time_start = datetime.datetime.now()
        # 1. Train the discriminator for k steps before the generator step
        for _ in range(k):
            # Sample latent vectors
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)  # sample real images
            # Discriminator forward pass
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x,
                                   is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(
                zip(grads, discriminator.trainable_variables))

        # 2. Train the generator
        # Sample latent vectors
        batch_z = tf.random.normal([batch_size, z_dim])
        # Generator forward pass
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Reports the losses of the last discriminator step and of the
        # generator step, e.g.
        # Epoch: 0/1 TimeUsed: 0:00:10.126834 d-loss: 1.45619345 g-loss: 0.63321948
        print(
            f"Epoch: {epoch}/{epochs} TimeUsed: {datetime.datetime.now()-time_start} d-loss: {d_loss:.8f} g-loss: {g_loss:.8f}"
        )

        if epoch % 100 == 0:
            z = tf.random.normal([100, z_dim])  # latent vectors for the preview
            fake_image = generator(z, training=False)
            # Separate name so the list of input files in img_path is not
            # overwritten by the output file path.
            snapshot_path = os.path.join(save_image_path,
                                         'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, snapshot_path, color_mode='P')

            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))

            generator.save_weights(generator_ckpt_path)
            discriminator.save_weights(discriminator_ckpt_path)