Example #1

# Assumed imports: numpy/tensorflow/logging are standard; GAN and
# flow_from_dir are project-local, so their module paths here are
# illustrative placeholders.
import logging

import numpy as np
import tensorflow as tf

from gan import GAN
from data_loader import flow_from_dir

logger = logging.getLogger(__name__)

def meta_learn():
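    """Meta-learning stage: train across many videos of the dataset."""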
    k = 8
    frame_shape = h, w, c = (256, 256, 3)
    input_embedder_shape = (h, w, k * c)
    BATCH_SIZE = 12
    num_videos = 145008  # Divisible by BATCH_SIZE; the full dataset has 145520 videos.
    num_batches = num_videos // BATCH_SIZE
    epochs = 75
    datapath = '../few-shot-learning-of-talking-heads/datasets/voxceleb2-9f/train/lndmks'

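    # Build the models on the CPU; gpus=0 disables any multi-GPU wrapper and
    # meta=True selects the variants that take the per-video condition input.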
    gan = GAN(input_shape=frame_shape, num_videos=num_videos, k=k)
    with tf.device("/cpu:0"):
        combined_to_train, combined, discriminator_to_train, discriminator = gan.compile_models(
            meta=True, gpus=0)
        embedder = gan.embedder
        generator = gan.generator
        intermediate_vgg19 = gan.intermediate_vgg19
        intermediate_vggface = gan.intermediate_vggface
        intermediate_discriminator = gan.intermediate_discriminator
        embedding_discriminator = gan.embedding_discriminator

    logger.info('=== discriminator ===')
    discriminator.summary(print_fn=logger.info)
    logger.info('=== generator ===')
    combined.get_layer('generator').summary(print_fn=logger.info)
    logger.info('=== embedder ===')
    combined.get_layer('embedder').summary(print_fn=logger.info)
    combined.summary(print_fn=logger.info)

    for epoch in range(epochs):
        logger.info('Epoch: {}'.format(epoch))
        for batch_ix, (frames, landmarks, styles, condition) in enumerate(
                flow_from_dir(datapath, num_videos, (h, w), BATCH_SIZE, k)):
            if batch_ix == num_batches:
                break
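            # Adversarial targets: +1 for real frames, -1 for fakes.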
            valid = np.ones((frames.shape[0], 1))
            invalid = -valid

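            # Cache activations of the real frames: regression targets for
            # the VGG19/VGGFace perceptual losses and for discriminator
            # feature matching in the combined model.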
            intermediate_vgg19_reals = intermediate_vgg19.predict_on_batch(
                frames)
            intermediate_vggface_reals = intermediate_vggface.predict_on_batch(
                frames)
            intermediate_discriminator_reals = intermediate_discriminator.predict_on_batch(
                [frames, landmarks])

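            # Split the k style frames into separate embedder inputs.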
            style_list = [styles[:, i, :, :, :] for i in range(k)]

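            # Per-video embedding from the discriminator, selected by the
            # condition input; it appears k times in the combined targets so
            # each style embedding is pulled toward it.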
            w_i = embedding_discriminator.predict_on_batch(condition)

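            # Single combined update; targets line up with the model outputs:
            # VGG19/VGGFace activations (perceptual), a "real" score
            # (adversarial), discriminator features (feature matching), and
            # w_i repeated for the k style embeddings.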
            g_loss = combined_to_train.train_on_batch(
                [landmarks] + style_list + [condition],
                intermediate_vgg19_reals + intermediate_vggface_reals +
                [valid] + intermediate_discriminator_reals + [w_i] * k)

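            # Discriminator update on real frames.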
            d_loss_real = discriminator_to_train.train_on_batch(
                [frames, landmarks, condition], [valid])

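            # Average the k style embeddings, generate fakes from them, and
            # update the discriminator to reject the fakes.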
            embeddings_list = [
                embedder.predict_on_batch(style) for style in style_list
            ]
            average_embedding = np.mean(np.array(embeddings_list), axis=0)
            fake_frames = generator.predict_on_batch(
                [landmarks, average_embedding])
            d_loss_fake = discriminator_to_train.train_on_batch(
                [fake_frames, landmarks, condition], [invalid])
            logger.info((epoch, batch_ix, g_loss, (d_loss_real, d_loss_fake)))

            if batch_ix % 100 == 0 and batch_ix > 0:
                # Save whole model
                # combined.save('trained_models/{}_meta_combined.h5'.format(epoch))
                # discriminator.save('trained_models/{}_meta_discriminator.h5'.format(epoch))

                # Save weights only
                # combined.save_weights('trained_models/{}_meta_combined_weights.h5'.format(epoch))
                combined.get_layer('generator').save_weights(
                    'trained_models/{}_meta_generator_in_combined.h5'.format(
                        epoch))
                combined.get_layer('embedder').save_weights(
                    'trained_models/{}_meta_embedder_in_combined.h5'.format(
                        epoch))
                discriminator.save_weights(
                    'trained_models/{}_meta_discriminator_weights.h5'.format(
                        epoch))
                logger.info(
                    'Checkpoint saved at Epoch: {}; batch_ix: {}'.format(
                        epoch, batch_ix))
        print()


def fewshot_learn():
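    """Few-shot stage: fine-tune the meta-learned models on one subject."""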
    metalearning_epoch = 0
    BATCH_SIZE = 1
    k = 1
    frame_shape = h, w, c = (256, 256, 3)
    input_embedder_shape = (h, w, k * c)
    num_videos = 1
    num_batches = 1
    epochs = 40
    dataname = 'monalisa'
    datapath = '../few-shot-learning-of-talking-heads/datasets/fewshot/' + dataname + '/lndmks'

    gan = GAN(input_shape=frame_shape, num_videos=num_videos, k=k)
    with tf.device("/cpu:0"):
        combined_to_train, combined, discriminator_to_train, discriminator = gan.compile_models(
            meta=False, gpus=0)
        embedder = gan.embedder
        generator = gan.generator
        intermediate_vgg19 = gan.intermediate_vgg19
        intermediate_vggface = gan.intermediate_vggface
        intermediate_discriminator = gan.intermediate_discriminator

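    # Warm-start from the meta-learned checkpoints. by_name with
    # skip_mismatch lets the discriminator skip layers whose shapes changed
    # between stages (num_videos is now 1, so the per-video embedding matrix
    # no longer matches).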
    discriminator.load_weights(
        'trained_models/{}_meta_discriminator_weights.h5'.format(
            metalearning_epoch),
        by_name=True,
        skip_mismatch=True)
    combined.get_layer('embedder').load_weights(
        'trained_models/{}_meta_embedder_in_combined.h5'.format(
            metalearning_epoch))
    combined.get_layer('generator').load_weights(
        'trained_models/{}_meta_generator_in_combined.h5'.format(
            metalearning_epoch))

    for epoch in range(epochs):
        for batch_ix, (frames, landmarks, styles) in enumerate(
                flow_from_dir(datapath,
                              num_videos, (h, w),
                              BATCH_SIZE,
                              k,
                              meta=False)):
            if batch_ix == num_batches:
                break
            valid = np.ones((frames.shape[0], 1))
            invalid = -valid

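            # Real-frame activations serve as targets, as in the
            # meta-learning stage.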
            intermediate_vgg19_outputs = intermediate_vgg19.predict_on_batch(
                frames)
            intermediate_vggface_outputs = intermediate_vggface.predict_on_batch(
                frames)
            intermediate_discriminator_outputs = intermediate_discriminator.predict_on_batch(
                [frames, landmarks])

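            # Only the style inputs are needed before the combined update;
            # embeddings and fakes are recomputed afterwards with the
            # updated embedder.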
            style_list = [styles[:, i, :, :, :] for i in range(k)]

            g_loss = combined_to_train.train_on_batch(
                [landmarks] + style_list,
                intermediate_vgg19_outputs + intermediate_vggface_outputs +
                [valid] + intermediate_discriminator_outputs)

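            # Re-embed the styles with the just-updated embedder before the
            # discriminator steps.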
            embeddings_list = [
                embedder.predict_on_batch(style) for style in style_list
            ]
            average_embedding = np.mean(np.array(embeddings_list), axis=0)

            d_loss_real = discriminator_to_train.train_on_batch(
                [frames, landmarks, average_embedding], [valid])

            fake_frames = generator.predict_on_batch(
                [landmarks, average_embedding])

            d_loss_fake = discriminator_to_train.train_on_batch(
                [fake_frames, landmarks, average_embedding], [invalid])
            logger.info((epoch, batch_ix, g_loss, (d_loss_real, d_loss_fake)))

    # Save whole model
    # combined.save('trained_models/{}_fewshot_combined.h5'.format(dataname))
    # discriminator.save('trained_models/{}_fewshot_discriminator.h5'.format(dataname))

    # Save weights only
    # combined.save_weights('trained_models/{}_fewshot_combined_weights.h5'.format(dataname))
    combined.get_layer('generator').save_weights(
        'trained_models/{}_fewshot_generator_in_combined.h5'.format(dataname))
    # combined.get_layer('embedder').save_weights('trained_models/{}_fewshot_embedder_in_combined.h5'.format(dataname))
    discriminator.save_weights(
        'trained_models/{}_fewshot_discriminator_weights.h5'.format(dataname))
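

# Minimal entry point (an assumption; the original module may be driven
# elsewhere): run meta-learning, then few-shot fine-tuning, using the
# constants defined in each function.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    meta_learn()
    fewshot_learn()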