Esempio n. 1
0
def train(model_name: str,
          n_cr: int,
          num_workers: int = 0,
          is_test: bool = False,
          resume_from_checkpoint: str = None):
    """Train a GAN variant with PyTorch Lightning and W&B logging.

    Args:
        model_name: one of "improved_gan" or "default_gan".
        n_cr: forwarded to the config factory.
        num_workers: dataloader worker count.
        is_test: use the lightweight test config and keep W&B offline.
        resume_from_checkpoint: optional checkpoint path to resume from.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    seed_everything(SEED)

    # Both supported models share the exact same config; they differ only
    # in the ``improved`` flag passed to the GAN constructor, so the two
    # previously duplicated branches are collapsed into one.
    if model_name not in ("improved_gan", "default_gan"):
        raise ValueError(f"Model {model_name} is not supported")
    config_function = get_gan_test_config if is_test else get_gan_default_config
    config = config_function(n_cr)
    model = GAN(config, num_workers, improved=(model_name == "improved_gan"))

    # define logger
    wandb_logger = WandbLogger(project="GAN", log_model=True, offline=is_test)
    wandb_logger.watch(model, log="all")
    # define model checkpoint callback (keeps the 3 best checkpoints)
    model_checkpoint_callback = ModelCheckpoint(
        filepath=join(wandb.run.dir, "{epoch:02d}-{val_loss:.4f}"),
        period=config.save_every_epoch,
        save_top_k=3,
    )
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateLogger()
    trainer = Trainer(
        max_epochs=config.n_epochs,
        deterministic=True,
        check_val_every_n_epoch=config.val_every_epoch,
        row_log_interval=config.log_every_epoch,
        logger=wandb_logger,
        checkpoint_callback=model_checkpoint_callback,
        resume_from_checkpoint=resume_from_checkpoint,
        gpus=gpu,
        callbacks=[lr_logger],
        reload_dataloaders_every_epoch=True,
    )

    trainer.fit(model)

    trainer.test()
Esempio n. 2
0
def create_model(config, model_name, num_classes=1):
    """Instantiate a model from a parsed config file.

    Args:
        config: ``configparser.ConfigParser`` with ``TRAINING`` and
            ``HYPERPARAMS`` sections.
        model_name: one of 'AE', 'VAE', 'VanillaGAN', 'DCGAN'.
        num_classes: class count, used only when the 'auxillary' flag is set.

    Returns:
        A single ``AE.Autoencoder`` for 'AE'/'VAE', or a
        ``(generator, discriminator)`` tuple for the GAN variants.

    Raises:
        NotImplementedError: for unknown ``model_name`` values.
    """
    section_training = config['TRAINING']
    # Optional training-section settings; None when the key is absent.
    RGB = section_training.getboolean('RGB') if 'RGB' in section_training else None
    input_size = tuple(map(int, section_training.get('input_size').split('x'))) if 'input_size' in section_training else None

    section_hyper = config['HYPERPARAMS']
    encode_factor = section_hyper.getint('encode_factor', None)
    base_channels = section_hyper.getint('base_channels', None)
    latent_dim = section_hyper.getint('latent_dim', 10)
    channel_increase_factor = section_hyper.getint('channel_increase_factor', 2)
    conv_blocks_per_decrease = section_hyper.getint('conv_blocks_per_decrease', 1)
    initial_upsample_size = section_hyper.getint('initial_upsample_size', 3)
    skip_connections = section_hyper.getboolean('skip_connections', False)
    # Class conditioning only applies when the auxiliary-classifier flag is on.
    num_classes = num_classes if config.getboolean('HYPERPARAMS', 'auxillary', fallback=False) else 0

    ## Init model
    if model_name in ['AE', 'VAE']:
        model = AE.Autoencoder(variational=model_name == 'VAE',
                               final_size=input_size,
                               encode_factor=encode_factor,
                               RGB=RGB,
                               base_channels=base_channels,
                               channel_increase_factor=channel_increase_factor,
                               conv_blocks_per_decrease=conv_blocks_per_decrease,
                               encoding_dimension=latent_dim,
                               initial_upsample_size=initial_upsample_size,
                               skip_connections=skip_connections,
                               n_classes=num_classes)
    elif model_name == 'VanillaGAN':
        # NOTE: the original re-read latent_dim from the same section with the
        # same fallback here; the value above is identical, so the duplicate
        # lookup was removed.
        gen = GAN.VanillaGenerator(input_dim=latent_dim, num_classes=num_classes)
        disc = GAN.VanillaDiscriminator(n_classes=num_classes)
        model = (gen, disc)
    elif model_name == 'DCGAN':
        gen = GAN.DCGenerator(input_dim=latent_dim, num_classes=num_classes)
        disc = GAN.DCDiscriminator(n_classes=num_classes)
        model = (gen, disc)
    else:
        raise NotImplementedError('The model you specified is not implemented yet')
    return model
def predict(lndmk_image_paths, style_frame_paths, style_lndmk_paths):
    """Generate fake frames from landmark images with a few-shot GAN.

    Loads the meta-trained embedder and the Mona-Lisa fine-tuned generator,
    embeds the style (landmark + frame) pairs, and renders one fake frame
    per input landmark image.

    Args:
        lndmk_image_paths: paths of landmark images driving the generator.
        style_frame_paths: paths of style frames (appearance reference).
        style_lndmk_paths: paths of landmarks matching the style frames.

    Returns:
        Array of generated frames, one per entry in ``lndmk_image_paths``.
    """
    frame_shape = (256, 256, 3)
    lndmk_images = np.array(
        [load_image(path, frame_shape) for path in lndmk_image_paths])
    # (the original also computed an unused ``batch_size`` local here)
    style_frame_images = np.array(
        [load_image(path, frame_shape) for path in style_frame_paths])
    style_lndmk_images = np.array(
        [load_image(path, frame_shape) for path in style_lndmk_paths])
    # Embedder input: style landmarks and frames concatenated on channels.
    style = np.concatenate((style_lndmk_images, style_frame_images), axis=-1)

    gan = GAN(input_shape=frame_shape, num_videos=1, k=1)
    gene = gan.build_generator()
    embe = gan.build_embedder()

    embe.load_weights('trained_models/0_meta_embedder_in_combined.h5')
    gene.load_weights(
        'trained_models/monalisa_fewshot_generator_in_combined.h5')

    style_embedding = embe.predict(style)
    style_embedding = style_embedding.repeat(8, axis=0)  # [512] -> [8, 512]
    fake_images = gene.predict([lndmk_images, style_embedding])
    return fake_images
Esempio n. 4
0
def main(hparams):
    """Build the Lightning GAN from ``hparams`` and fit it on one GPU."""
    # Model first, then a single-GPU trainer, then the fit itself.
    gan_model = GAN(hparams)
    single_gpu_trainer = Trainer(gpus=1)
    single_gpu_trainer.fit(gan_model)
Esempio n. 5
0
File: main.py — Project: zyzisyz/GANs
def main(args: Namespace) -> None:
    """Train the GAN on MNIST using CLI-parsed arguments."""
    # Lightning model built from the parsed arguments.
    gan = GAN(args)

    # If use distubuted training  PyTorch recommends to use DistributedDataParallel.
    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    datamodule = MNISTDataModule.from_argparse_args(args)
    trainer = Trainer.from_argparse_args(args)

    # Kick off training with the data module supplying the loaders.
    trainer.fit(gan, datamodule)
Esempio n. 6
0
def main(argv):
    """Load a trained generator and save a sample image grid.

    Args:
        argv: ``argv[1]`` is the output directory (created if missing).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(argv[1], exist_ok=True)

    model = GAN(latent_dim=100, batch_size=128, device=device)
    model.move_to_device()
    # Only the generator weights are needed for sampling.
    model.G.load_state_dict(torch.load("./model/gan-G.pkl"))

    fname = os.path.join(argv[1], "fig1_2.jpg")

    model.save_image(fname)
def main():
    """Entry point: build the GAN, optionally resume, then train or test."""
    # Ensure output directories exist; exist_ok avoids the check-then-create
    # race of os.path.exists + makedirs.
    os.makedirs(os.path.join(config.tensorboard_dir, config.name), exist_ok=True)
    os.makedirs(os.path.join(config.checkpoint_dir, config.name), exist_ok=True)

    device = torch.device('cuda:0' if config.use_cuda else 'cpu')
    models = GAN(config).to(device)
    # Resume from a checkpoint when a non-zero epoch is configured.
    if config.load_epoch != 0:
        load_checkpoints(models, config.checkpoint_dir, config.name,
                         config.load_epoch)

    if config.is_train:
        models.train()
        writer = SummaryWriter(
            log_dir=os.path.join(config.tensorboard_dir, config.name))
        train(models, writer, device)
    else:
        models.eval()
        test(models, device)
Esempio n. 8
0
def load_gan(backbone_data=None,
             gan_weights='gan_weights.h5',
             backbone_weights='backbone_posttrained_weights.h5',
             generator_weights=None,
             discriminator_weights=None,
             clear_session=True):
    """Assemble a GAN from pretrained backbone/generator/discriminator weights.

    Args:
        backbone_data: dataset wrapper providing ``get_test()`` and
            ``get_n_samples()``; required despite the ``None`` default.
        gan_weights: optional path to full-GAN weights loaded at the end.
        backbone_weights: weights forwarded to the sub-model loaders.
        generator_weights: optional generator-only weights.
        discriminator_weights: optional discriminator-only weights.
        clear_session: reset the Keras session before building.

    Returns:
        The assembled (and optionally weight-loaded) GAN.

    Raises:
        ValueError: if ``backbone_data`` is not provided.
    """
    # backbone_data is dereferenced below, so fail fast with a clear error
    # instead of an AttributeError on None.
    if backbone_data is None:
        raise ValueError("backbone_data is required")

    if clear_session:
        keras.backend.clear_session()

    # Build the backbone and run one forward pass so its weights are created.
    backbone = ResNet()
    backbone(backbone_data.get_test()[0])

    discriminator = load_discriminator(
        data=backbone_data,
        clear_session=False,
        backbone_weights=backbone_weights,
        discriminator_weights=discriminator_weights)
    generator = load_generator(backbone_data=backbone_data,
                               clear_session=False,
                               backbone_weights=backbone_weights,
                               generator_weights=generator_weights)

    gan = GAN(generator=generator, discriminator=discriminator)

    if gan_weights:
        # Run one predict so the model is built before loading GAN weights.
        input_shape = generator.get_input_shape()
        input_x, _ = get_gan_data(
            backbone_data.get_n_samples(10)[0], input_shape)
        gan.predict(input_x)
        gan.load_weights(gan_weights)

    return gan
def train(n_channels=3,
          resolution=32,
          z_dim=128,
          n_labels=0,
          lr=1e-3,
          e_drift=1e-3,
          wgp_target=750,
          initial_resolution=4,
          total_kimg=25000,
          training_kimg=500,
          transition_kimg=500,
          iters_per_checkpoint=500,
          n_checkpoint_images=16,
          glob_str='cifar10',
          out_dir='cifar10'):
    """Train a progressively-growing GAN with a WGAN-GP-style objective.

    Training proceeds in phases: each phase trains at a fixed level of
    detail for ``training_kimg`` thousand images, then fades in the next
    resolution block over ``transition_kimg`` thousand images (tracked via
    the fractional ``cur_block`` variable set on both compiled graphs).

    Args:
        n_channels: image channels (3 for RGB).
        resolution: final output resolution (assumed a power of two).
        z_dim: latent vector dimensionality.
        n_labels: number of class labels (0 for unconditional).
        lr: Adam learning rate for both networks.
        e_drift: weight of the drift ('mse') term in the D loss.
        wgp_target: target value for the gradient-penalty term.
        initial_resolution: resolution the first phase trains at.
        total_kimg: total training length in thousands of images.
        training_kimg: stable-training duration per phase (kimg).
        transition_kimg: fade-in duration per phase (kimg).
        iters_per_checkpoint: image-logging / model-saving interval.
        n_checkpoint_images: images rendered per checkpoint grid.
        glob_str: dataset glob passed to the minibatch iterator.
        out_dir: directory for TensorBoard logs and saved models.
    """

    # instantiate logger
    logger = SummaryWriter(out_dir)

    # load data
    batch_size = MINIBATCH_OVERWRITES[0]
    train_iterator = iterate_minibatches(glob_str, batch_size, resolution)

    # build models
    G = Generator(n_channels, resolution, z_dim, n_labels)
    D = Discriminator(n_channels, resolution, n_labels)

    G_train, D_train = GAN(G, D, z_dim, n_labels, resolution, n_channels)

    D_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    G_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

    # define loss functions
    D_loss = [loss_mean, loss_gradient_penalty, 'mse']
    G_loss = [loss_wasserstein]

    # compile graphs used during training; D is frozen while compiling the
    # generator graph so G updates do not touch discriminator weights
    G.compile(G_opt, loss=loss_wasserstein)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=[1, GP_WEIGHT, e_drift])

    # for computing the loss
    ones = np.ones((batch_size, 1), dtype=np.float32)
    zeros = ones * 0.0

    # fix a z vector for training evaluation
    z_fixed = np.random.normal(0, 1, size=(n_checkpoint_images, z_dim))

    # vars
    resolution_log2 = int(np.log2(resolution))
    starting_block = resolution_log2
    starting_block -= np.floor(np.log2(initial_resolution))
    cur_block = starting_block
    cur_nimg = 0

    # compute duration of each phase and use proxy to update minibatch size
    phase_kdur = training_kimg + transition_kimg
    phase_idx_prev = 0

    # offset variable for transitioning between blocks
    offset = 0
    i = 0
    while cur_nimg < total_kimg * 1000:
        # block processing
        kimg = cur_nimg / 1000.0
        phase_idx = int(np.floor((kimg + transition_kimg) / phase_kdur))
        phase_idx = max(phase_idx, 0.0)
        phase_kimg = phase_idx * phase_kdur

        # update batch size and ones vector if we switched phases
        if phase_idx_prev < phase_idx:
            batch_size = MINIBATCH_OVERWRITES[phase_idx]
            # NOTE(review): unlike the initial call, 'resolution' is not
            # passed here — confirm iterate_minibatches' default argument.
            train_iterator = iterate_minibatches(glob_str, batch_size)
            ones = np.ones((batch_size, 1), dtype=np.float32)
            zeros = ones * 0.0
            phase_idx_prev = phase_idx

        # possibly gradually update current level of detail
        if transition_kimg > 0 and phase_idx > 0:
            offset = (kimg + transition_kimg - phase_kimg) / transition_kimg
            offset = min(offset, 1.0)
            offset = offset + phase_idx - 1
            cur_block = max(starting_block - offset, 0.0)

        # update level of detail
        K.set_value(G_train.cur_block, np.float32(cur_block))
        K.set_value(D_train.cur_block, np.float32(cur_block))

        # train D several times per G step (critic iterations)
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(train_iterator)
            fake_batch = G.predict_on_batch([z])
            interpolated_batch = get_interpolated_images(
                real_batch, fake_batch)
            losses_d = D_train.train_on_batch(
                [real_batch, fake_batch, interpolated_batch],
                [ones, ones * wgp_target, zeros])
            cur_nimg += batch_size

        # train G
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        loss_g = G_train.train_on_batch(z, -1 * ones)

        logger.add_scalar("cur_block", cur_block, i)
        logger.add_scalar("learning_rate", lr, i)
        logger.add_scalar("batch_size", z.shape[0], i)
        print("iter", i, "cur_block", cur_block, "lr", lr, "kimg", kimg,
              "losses_d", losses_d, "loss_g", loss_g)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_images = G.predict(z_fixed)
            # log fake images
            log_images(fake_images, 'fake', i, logger, fake_images.shape[1],
                       fake_images.shape[2], int(np.sqrt(n_checkpoint_images)))

            # plot real images for reference
            log_images(real_batch[:n_checkpoint_images], 'real', i, logger,
                       real_batch.shape[1], real_batch.shape[2],
                       int(np.sqrt(n_checkpoint_images)))

            # save the model to eventually resume training or do inference
            save_model(G, out_dir + "/model.json", out_dir + "/model.h5")

        log_losses(losses_d, loss_g, i, logger)
        i += 1
Esempio n. 10
0
# Evaluation loader: no shuffling so results are reproducible across runs.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

print(f"""
Total training data: {len(trainset)}
Total testing data: {len(testset)}
Total data: {len(trainset) + len(testset)}
""",
      flush=True)

# 2. instantiate the network model
Z_DIM = 100
# reference: https://machinelearningmastery.com/how-to-train-stable-generative-adversarial-networks/
# Generator: Z_DIM-dim latent vector -> 784-dim output through two 256-unit
# hidden layers; Tanh output keeps samples in [-1, 1].
# (784 is presumably 28x28 images flattened — TODO confirm against dataset.)
G = GAN.Generator(layers_dim=[Z_DIM, 256, 256, 784],
                  internal_activation=nn.ReLU(),
                  output_activation=nn.Tanh())
# Discriminator mirrors the generator: 784-dim input down to a single
# real/fake score squashed by the final Sigmoid.
D = GAN.Discriminator(layers_dim=[784, 256, 256, 1],
                      internal_activation=nn.LeakyReLU(0.2),
                      output_activation=nn.Sigmoid())
print(f"""
Generator G:
{G}

Discriminator D:
{D}
""", flush=True)

# if we train with multiple GPUs, we need to use
# net.module.generate()
# with only one GPU, we can simply use
def fewshot_learn():
    """Fine-tune the meta-learned talking-heads GAN on a single identity.

    Loads the weights produced by meta-learning (epoch ``metalearning_epoch``),
    then runs ``epochs`` few-shot epochs over the 'monalisa' landmark data,
    alternating a combined (generator/embedder) update with real/fake
    discriminator updates, and finally saves the fine-tuned generator and
    discriminator weights.
    """
    metalearning_epoch = 0
    BATCH_SIZE = 1
    k = 1  # number of style frames per identity in the few-shot setting
    frame_shape = h, w, c = (256, 256, 3)
    input_embedder_shape = (h, w, k * c)
    # NOTE(review): BATCH_SIZE is assigned twice with the same value above
    # and here — the duplicate is harmless but redundant.
    BATCH_SIZE = 1
    num_videos = 1
    num_batches = 1
    epochs = 40
    dataname = 'monalisa'
    datapath = '../few-shot-learning-of-talking-heads/datasets/fewshot/' + dataname + '/lndmks'

    gan = GAN(input_shape=frame_shape, num_videos=num_videos, k=k)
    # Build/compile on CPU (gpus=0: no multi-GPU replication here).
    with tf.device("/cpu:0"):
        combined_to_train, combined, discriminator_to_train, discriminator = gan.compile_models(
            meta=False, gpus=0)
        embedder = gan.embedder
        generator = gan.generator
        intermediate_vgg19 = gan.intermediate_vgg19
        intermediate_vggface = gan.intermediate_vggface
        intermediate_discriminator = gan.intermediate_discriminator

    # Load the meta-trained starting weights; by_name/skip_mismatch lets the
    # few-shot discriminator reuse matching layers only.
    discriminator.load_weights(
        'trained_models/{}_meta_discriminator_weights.h5'.format(
            metalearning_epoch),
        by_name=True,
        skip_mismatch=True)
    combined.get_layer('embedder').load_weights(
        'trained_models/{}_meta_embedder_in_combined.h5'.format(
            metalearning_epoch))
    combined.get_layer('generator').load_weights(
        'trained_models/{}_meta_generator_in_combined.h5'.format(
            metalearning_epoch))

    for epoch in range(epochs):
        for batch_ix, (frames, landmarks, styles) in enumerate(
                flow_from_dir(datapath,
                              num_videos, (h, w),
                              BATCH_SIZE,
                              k,
                              meta=False)):
            if batch_ix == num_batches:
                break
            # Hinge-style real/fake targets sized to the actual batch.
            valid = np.ones((frames.shape[0], 1))
            invalid = -valid

            # Perceptual targets computed from the real frames.
            intermediate_vgg19_outputs = intermediate_vgg19.predict_on_batch(
                frames)
            intermediate_vggface_outputs = intermediate_vggface.predict_on_batch(
                frames)
            intermediate_discriminator_outputs = intermediate_discriminator.predict_on_batch(
                [frames, landmarks])

            style_list = [styles[:, i, :, :, :] for i in range(k)]
            embeddings_list = [
                embedder.predict_on_batch(style) for style in style_list
            ]
            average_embedding = np.mean(np.array(embeddings_list), axis=0)
            # NOTE(review): this fake_frames result is unused — it is
            # recomputed after the generator update below.
            fake_frames = generator.predict_on_batch(
                [landmarks, average_embedding])

            # Generator/embedder update via the combined model.
            g_loss = combined_to_train.train_on_batch(
                [landmarks] + style_list,
                intermediate_vgg19_outputs + intermediate_vggface_outputs +
                [valid] + intermediate_discriminator_outputs)

            # Re-embed with the just-updated embedder for the D updates.
            embeddings_list = [
                embedder.predict_on_batch(style) for style in style_list
            ]
            average_embedding = np.mean(np.array(embeddings_list), axis=0)

            d_loss_real = discriminator_to_train.train_on_batch(
                [frames, landmarks, average_embedding], [valid])

            fake_frames = generator.predict_on_batch(
                [landmarks, average_embedding])

            d_loss_fake = discriminator_to_train.train_on_batch(
                [fake_frames, landmarks, average_embedding], [invalid])
            logger.info((epoch, batch_ix, g_loss, (d_loss_real, d_loss_fake)))

    # Save whole model


#    combined.save('trained_models/{}_fewshot_combined.h5'.format(dataname))
#    discriminator.save('trained_models/{}_fewshot_discriminator.h5'.format(dataname))

# Save weights only
#    combined.save_weights('trained_models/{}_fewshot_combined_weights.h5'.format(dataname))
    combined.get_layer('generator').save_weights(
        'trained_models/{}_fewshot_generator_in_combined.h5'.format(dataname))
    #    combined.get_layer('embedder').save_weights('trained_models/{}_fewshot_embedder_in_combined.h5'.format(dataname))
    discriminator.save_weights(
        'trained_models/{}_fewshot_discriminator_weights.h5'.format(dataname))
Esempio n. 12
0
import settings
import torch
from torch.optim import lr_scheduler
from models import GAN, Detector
from data import get_dataloader
import itertools
from utils import *

# Build the models and the training dataloader.
GAN_model = GAN()
Detector_model = Detector()
dataset = get_dataloader(batch_size=settings.batch_size,
                         num_workers=settings.num_workers)

# Optimizer hyperparameters.
beta1 = 0.5
gan_lr = 0.0002
detector_lr = 1e-3

# Both generators (A->B and B->A) share one Adam optimizer.
optimizer_G = torch.optim.Adam(itertools.chain(GAN_model.netG_A.parameters(),
                                               GAN_model.netG_B.parameters()),
                               lr=gan_lr,
                               betas=(beta1, 0.999))
# Both global discriminators plus the local discriminator share another.
optimizer_D = torch.optim.Adam(itertools.chain(GAN_model.netD_A.parameters(),
                                               GAN_model.netD_B.parameters(),
                                               GAN_model.localD.parameters()),
                               lr=gan_lr,
                               betas=(beta1, 0.999))
# BUG FIX: the detector sub-modules were chained directly, but Adam expects
# an iterable of Parameters — chain their .parameters() like the two
# optimizers above. (Assumes backbone/detector16/detector8 are nn.Modules,
# consistent with the GAN_model usage; confirm against the Detector class.)
optimizer_Detector = torch.optim.Adam(itertools.chain(
    Detector_model.backbone.parameters(),
    Detector_model.detector16.parameters(),
    Detector_model.detector8.parameters()),
                                      lr=detector_lr)
Esempio n. 13
0
File: main.py — Project: zyzisyz/GANs
    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    # If use distubuted training  PyTorch recommends to use DistributedDataParallel.
    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    dm = MNISTDataModule.from_argparse_args(args)

    trainer = Trainer.from_argparse_args(args)

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model, dm)


if __name__ == '__main__':
    # Compose one ArgumentParser from the data module, the model and the
    # trainer so all of their options share a single CLI namespace.
    parser = ArgumentParser()

    # Add program level args, if any.
    # ------------------------
    # Add LightningDataLoader args
    parser = MNISTDataModule.add_argparse_args(parser)
    # Add model specific args
    parser = GAN.add_argparse_args(parser)
    # Add trainer args
    parser = Trainer.add_argparse_args(parser)
    # Parse all arguments
    args = parser.parse_args()

    main(args)
Esempio n. 14
0
if __name__ == '__main__':
    # NOTE(review): placeholder-looking path ('...') — confirm before running.
    imageDir = '...celeb_images/Part 1'

    image_path_list = []
    for file in os.listdir(imageDir):
        image_path_list.append(os.path.join(imageDir, file))

    # Pre-allocate for every listed file even though at most 5001 entries are
    # filled below; the slice afterwards trims to the first 5000.
    image = np.empty([len(image_path_list), 64, 64, 3])
    for indx, imagePath in enumerate(image_path_list):
        # Breaks at indx == 5001, so indices 0..5000 get filled; the extra
        # row at index 5000 is discarded by the [:5000] slice below.
        if(indx>5000):
            break
        im = Image.open(imagePath).convert('RGB')
        im = im.resize((64, 64))
        im=np.array(im,dtype=np.float32)
        # Scale to [0, 1] then normalize each channel to roughly [-1, 1].
        im=im/255
        im=Normalize(im, [0.5,0.5,0.5], [0.5,0.5,0.5])
        image[indx,:,:,:] = im

    image = image[:5000,:,:,:]


    # Train for 200 epochs and get back the generator/discriminator pair.
    gan = GAN()
    gen, dis = gan.train(image,200)

    # Sample one image from random noise and undo the normalization so it
    # can be displayed directly.
    noise = Variable(torch.randn(1, 100, 1, 1)).float().cuda()
    im = gen.feed_forward(noise)
    im = im.permute(0, 2, 3, 1)
    im = im.squeeze(0).cpu().detach().numpy()*255
    im = DeNormalize(im,[0.5,0.5,0.5],[0.5,0.5,0.5])

    plt.imshow(im)
Esempio n. 15
0
            elif lambdas == 'naive':
                lambda0, lambda1 = 'naive', lambda1_
            example_dir = './trained_gan/{}_{}_GAN_{}_{}'.format(
                dataset, function, lambda0, lambda1)

            list_div = []
            list_qa = []
            list_oa = []
            for i in range(10):
                directory = example_dir + '/{}'.format(i)
                npy_path = directory + '/scores.npy'
                if os.path.exists(npy_path):
                    div, qa, oa = np.load(npy_path)
                else:
                    # Generated data
                    model = GAN(2, 2, lambda0, lambda1)
                    model.restore(save_dir=directory)
                    gen_data = model.synthesize(N)
                    # Compute metrics
                    div = diversity_score(gen_data, subset_size, sample_times)
                    qa = quality_score(gen_data, func_obj)
                    oa = overall_score(gen_data, func_obj)
                    np.save(npy_path, [div, qa, oa])
                list_div.append(div)
                list_qa.append(qa)
                list_oa.append(oa)
            list_div_lambdas.append(list_div)
            list_qa_lambdas.append(list_qa)
            list_oa_lambdas.append(list_oa)

        fig = plt.figure(figsize=(15, 5))
Esempio n. 16
0
import torch.nn as nn
import torchvision
from datetime import datetime
import os
import models.GAN as GAN
from collections import OrderedDict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the model
MODEL_PATH = os.path.dirname(os.path.realpath(__file__)) + '/../saved_models/'
MODEL_NAME = 'gan-model-epoch100.pth'
Z_DIM = 100

G = GAN.Generator(layers_dim=[100, 256, 256, 784],
                  internal_activation=nn.ReLU(),
                  output_activation=nn.Tanh())

checkpoint = torch.load(MODEL_PATH + MODEL_NAME, map_location=device)
old_G_state_dict = checkpoint.get('G_state_dict')

# Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
# BUG FIX: detect the prefix with startswith on the first key (the old
# substring 'in' test could match anywhere in the key) and strip it by its
# length instead of the magic slice key[7:].
_DP_PREFIX = 'module.'
if next(iter(old_G_state_dict)).startswith(_DP_PREFIX):
    new_G_state_dict = OrderedDict(
        (key[len(_DP_PREFIX):], value)
        for key, value in old_G_state_dict.items())
    G.load_state_dict(new_G_state_dict)
else:
    G.load_state_dict(old_G_state_dict)

# One latent sample ready for generation.
sample = torch.randn((1, Z_DIM))
Esempio n. 17
0
def main():
    """Fit a freshly constructed GAN with a default Lightning Trainer."""
    gan_model = GAN()
    Trainer().fit(gan_model)
Esempio n. 18
0
def meta_learn():
    """Meta-train the few-shot talking-heads GAN on VoxCeleb2 landmarks.

    Builds the combined generator/embedder model and the discriminator,
    logs their summaries, then for each epoch alternates a combined
    (generator) update with real/fake discriminator updates, checkpointing
    all weights every 100 batches.
    """
    k = 8  # style frames sampled per video
    frame_shape = h, w, c = (256, 256, 3)
    input_embedder_shape = (h, w, k * c)
    BATCH_SIZE = 12
    num_videos = 145008  # This is dividable by BATCH_SIZE. All data is 145520
    num_batches = num_videos // BATCH_SIZE
    epochs = 75
    datapath = '../few-shot-learning-of-talking-heads/datasets/voxceleb2-9f/train/lndmks'

    gan = GAN(input_shape=frame_shape, num_videos=num_videos, k=k)
    # Build/compile on CPU (gpus=0: no multi-GPU replication here).
    with tf.device("/cpu:0"):
        combined_to_train, combined, discriminator_to_train, discriminator = gan.compile_models(
            meta=True, gpus=0)
        embedder = gan.embedder
        generator = gan.generator
        intermediate_vgg19 = gan.intermediate_vgg19
        intermediate_vggface = gan.intermediate_vggface
        intermediate_discriminator = gan.intermediate_discriminator
        embedding_discriminator = gan.embedding_discriminator

    logger.info('==== discriminator ===')
    discriminator.summary(print_fn=logger.info)
    logger.info('=== generator ===')
    combined.get_layer('generator').summary(print_fn=logger.info)
    logger.info('=== embedder ===')
    combined.get_layer('embedder').summary(print_fn=logger.info)
    combined.summary(print_fn=logger.info)

    for epoch in range(epochs):
        logger.info(('Epoch: ', epoch))
        for batch_ix, (frames, landmarks, styles, condition) in enumerate(
                flow_from_dir(datapath, num_videos, (h, w), BATCH_SIZE, k)):
            if batch_ix == num_batches:
                break
            # Real/fake targets sized to the actual batch.
            valid = np.ones((frames.shape[0], 1))
            invalid = -valid

            # Perceptual targets computed from the real frames.
            intermediate_vgg19_reals = intermediate_vgg19.predict_on_batch(
                frames)
            intermediate_vggface_reals = intermediate_vggface.predict_on_batch(
                frames)
            intermediate_discriminator_reals = intermediate_discriminator.predict_on_batch(
                [frames, landmarks])

            style_list = [styles[:, i, :, :, :] for i in range(k)]

            # Per-video embedding target used for the embedding-match loss.
            w_i = embedding_discriminator.predict_on_batch(condition)

            # Generator/embedder update via the combined model.
            g_loss = combined_to_train.train_on_batch(
                [landmarks] + style_list + [condition],
                intermediate_vgg19_reals + intermediate_vggface_reals +
                [valid] + intermediate_discriminator_reals + [w_i] * k)

            d_loss_real = discriminator_to_train.train_on_batch(
                [frames, landmarks, condition], [valid])

            # Average the k style embeddings to condition fake generation.
            embeddings_list = [
                embedder.predict_on_batch(style) for style in style_list
            ]
            average_embedding = np.mean(np.array(embeddings_list), axis=0)
            fake_frames = generator.predict_on_batch(
                [landmarks, average_embedding])
            d_loss_fake = discriminator_to_train.train_on_batch(
                [fake_frames, landmarks, condition], [invalid])
            logger.info((epoch, batch_ix, g_loss, (d_loss_real, d_loss_fake)))

            if batch_ix % 100 == 0 and batch_ix > 0:
                # Save whole model
                # combined.save('trained_models/{}_meta_combined.h5'.format(epoch))
                # discriminator.save('trained_models/{}_meta_discriminator.h5'.format(epoch))

                # Save weights only
                #                combined.save_weights('trained_models/{}_meta_combined_weights.h5'.format(epoch))
                combined.get_layer('generator').save_weights(
                    'trained_models/{}_meta_generator_in_combined.h5'.format(
                        epoch))
                combined.get_layer('embedder').save_weights(
                    'trained_models/{}_meta_embedder_in_combined.h5'.format(
                        epoch))
                discriminator.save_weights(
                    'trained_models/{}_meta_discriminator_weights.h5'.format(
                        epoch))
                logger.info(
                    'Checkpoint saved at Epoch: {}; batch_ix: {}'.format(
                        epoch, batch_ix))
        print()
Esempio n. 19
0
def train():
    """Train a pix2pix-style night-to-day GAN with Keras.

    Parses CLI options, prepares the result directories and a loss log,
    loads the paired datasets, builds and compiles the generator,
    discriminator and combined GAN, then alternates discriminator and GAN
    updates every epoch, logging training and validation losses and
    periodically saving sample images and GAN weights.
    """
    parser = argparse.ArgumentParser(description="keras pix2pix")
    parser.add_argument('--batchsize', '-b', type=int, default=1)
    parser.add_argument('--patchsize', '-p', type=int, default=64)
    parser.add_argument('--epoch', '-e', type=int, default=500)
    parser.add_argument('--out', '-o', default='result')
    parser.add_argument('--lmd', '-l', type=int, default=100)
    parser.add_argument('--dark', '-d', type=float, default=0.01)
    parser.add_argument('--gpu', '-g', type=int, default=2)
    # parse_args() was previously called twice; once is enough.
    args = parser.parse_args()
    BATCH_SIZE = args.batchsize
    epoch = args.epoch
    lmd = args.lmd

    # set gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.set_session(sess)

    # make directories to save results; makedirs with exist_ok creates the
    # whole ./result/<out>/model/ chain and avoids the check-then-create race
    resultDir = "./result/" + args.out
    modelDir = resultDir + "/model/"
    os.makedirs(modelDir, exist_ok=True)

    # make a logfile and add colnames
    with open(resultDir + "/log.txt", "w") as o:
        o.write("batch:" + str(BATCH_SIZE) + "  lambda:" + str(lmd) + "\n")
        o.write(
            "epoch,dis_loss,gan_mae,gan_entropy,vdis_loss,vgan_mae,vgan_entropy" +
            "\n")

    # load data: dataset 1 is ids 1..1145, dataset 2 is ids 2000..6749;
    # the first 70% of each goes to training, the rest to validation
    ds1_first, ds1_last, num_ds1 = 1, 1145, 1145
    ds2_first, ds2_last, num_ds2 = 2000, 6749, 4750
    # ds1_first, ds1_last, num_ds1 = 1,    100, 100
    # ds2_first, ds2_last, num_ds2 = 101, 200, 100
    train_data_i = np.concatenate([
        np.arange(ds1_first, ds1_last + 1)[:int(num_ds1 * 0.7)],
        np.arange(ds2_first, ds2_last + 1)[:int(num_ds2 * 0.7)]
    ])
    test_data_i = np.concatenate([
        np.arange(ds1_first, ds1_last + 1)[int(num_ds1 * 0.7):],
        np.arange(ds2_first, ds2_last + 1)[int(num_ds2 * 0.7):]
    ])
    train_gt, _, train_night = load_dataset(data_range=train_data_i,
                                            dark=args.dark)
    test_gt, _, test_night = load_dataset(data_range=test_data_i,
                                          dark=args.dark)

    # Create optimizers
    opt_Gan = Adam(lr=1E-3)
    opt_Discriminator = Adam(lr=1E-3)
    opt_Generator = Adam(lr=1E-3)

    # set the loss of gan
    def dis_entropy(y_true, y_pred):
        # Cross-entropy-like loss on |pred - true|; epsilon avoids log(0).
        return -K.log(K.abs((y_pred - y_true)) + 1e-07)

    gan_loss = ['mae', dis_entropy]
    gan_loss_weights = [lmd, 1]

    # make models; the discriminator is frozen while compiling the combined
    # GAN so that GAN updates train only the generator
    Generator = generator()
    Generator.compile(loss='mae', optimizer=opt_Generator)
    Discriminator = discriminator()
    Discriminator.trainable = False
    Gan = GAN(Generator, Discriminator)
    Gan.compile(loss=gan_loss,
                loss_weights=gan_loss_weights,
                optimizer=opt_Gan)
    Discriminator.trainable = True
    Discriminator.compile(loss=dis_entropy, optimizer=opt_Discriminator)

    # start training
    n_train = train_gt.shape[0]
    n_test = test_gt.shape[0]
    print(n_train, n_test)
    p = ProgressBar()
    for epoch in p(range(epoch)):
        p.update(epoch + 1)
        out_file = open(resultDir + "/log.txt", "a")
        train_ind = np.random.permutation(n_train)
        test_ind = np.random.permutation(n_test)
        dis_losses = []
        gan_losses = []
        test_dis_losses = []
        test_gan_losses = []
        y_real = np.array([1] * BATCH_SIZE)
        y_fake = np.array([0] * BATCH_SIZE)
        y_gan = np.array([1] * BATCH_SIZE)

        # training
        for batch_i in range(int(n_train / BATCH_SIZE)):
            gt_batch = train_gt[train_ind[(batch_i *
                                           BATCH_SIZE):((batch_i + 1) *
                                                        BATCH_SIZE)], :, :, :]
            night_batch = train_night[train_ind[(
                batch_i * BATCH_SIZE):((batch_i + 1) * BATCH_SIZE)], :, :, :]
            generated_batch = Generator.predict(night_batch)
            # train Discriminator on real and generated pairs
            dis_real_loss = np.array(
                Discriminator.train_on_batch([night_batch, gt_batch], y_real))
            dis_fake_loss = np.array(
                Discriminator.train_on_batch([night_batch, generated_batch],
                                             y_fake))
            dis_loss_batch = (dis_real_loss + dis_fake_loss) / 2
            dis_losses.append(dis_loss_batch)
            gan_loss_batch = np.array(
                Gan.train_on_batch(night_batch, [gt_batch, y_gan]))
            gan_losses.append(gan_loss_batch)
        dis_loss = np.mean(np.array(dis_losses))
        gan_loss = np.mean(np.array(gan_losses), axis=0)

        # validation
        for batch_i in range(int(n_test / BATCH_SIZE)):
            gt_batch = test_gt[test_ind[(batch_i *
                                         BATCH_SIZE):((batch_i + 1) *
                                                      BATCH_SIZE)], :, :, :]
            night_batch = test_night[test_ind[(
                batch_i * BATCH_SIZE):((batch_i + 1) * BATCH_SIZE)], :, :, :]
            generated_batch = Generator.predict(night_batch)
            # evaluate Discriminator and GAN without updating weights
            dis_real_loss = np.array(
                Discriminator.test_on_batch([night_batch, gt_batch], y_real))
            dis_fake_loss = np.array(
                Discriminator.test_on_batch([night_batch, generated_batch],
                                            y_fake))
            test_dis_loss_batch = (dis_real_loss + dis_fake_loss) / 2
            test_dis_losses.append(test_dis_loss_batch)
            test_gan_loss_batch = np.array(
                Gan.test_on_batch(night_batch, [gt_batch, y_gan]))
            test_gan_losses.append(test_gan_loss_batch)
        test_dis_loss = np.mean(np.array(test_dis_losses))
        # BUG FIX: the validation GAN loss previously averaged the *training*
        # gan_losses; it must average the collected test_gan_losses.
        test_gan_loss = np.mean(np.array(test_gan_losses), axis=0)
        # write log of learning
        out_file.write(
            str(epoch) + "," + str(dis_loss) + "," + str(gan_loss[1]) + "," +
            str(gan_loss[2]) + "," + str(test_dis_loss) + "," +
            str(test_gan_loss[1]) + "," + str(test_gan_loss[2]) + "\n")

        # visualize sample outputs every 50 epochs
        if epoch % 50 == 0:
            # for training data
            gt_batch = train_gt[train_ind[0:9], :, :, :]
            night_batch = train_night[train_ind[0:9], :, :, :]
            generated_batch = Generator.predict(night_batch)
            save_images(night_batch,
                        resultDir + "/label_" + str(epoch) + "epoch.png")
            save_images(gt_batch,
                        resultDir + "/gt_" + str(epoch) + "epoch.png")
            save_images(generated_batch,
                        resultDir + "/generated_" + str(epoch) + "epoch.png")
            # for validation data
            gt_batch = test_gt[test_ind[0:9], :, :, :]
            night_batch = test_night[test_ind[0:9], :, :, :]
            generated_batch = Generator.predict(night_batch)
            save_images(night_batch,
                        resultDir + "/vlabel_" + str(epoch) + "epoch.png")
            save_images(gt_batch,
                        resultDir + "/vgt_" + str(epoch) + "epoch.png")
            save_images(generated_batch,
                        resultDir + "/vgenerated_" + str(epoch) + "epoch.png")

            Gan.save_weights(modelDir + 'gan_weights' + "_lambda" + str(lmd) +
                             "_epoch" + str(epoch) + '.h5')

        out_file.close()
    # (removed a stray out_file.close() after the loop: the handle is closed
    # at the end of every iteration, and the name is undefined if the loop
    # never runs)