def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.PRO_GAN import ConditionalProGAN

    config = get_config(args.config)

    print("Create dataset...")
    # create the dataset for training
    if config.use_pretrained_encoder:
        print("Using PretrainedEncoder...")
        if not os.path.exists(
                f"text_encoder_{config.tensorboard_comment}.pickle"):

            print("Creating new vocab and dataset pickle files ...")
            dataset = dl.RawTextFace2TextDataset(
                data_path=config.data_path,
                img_dir=config.images_dir,
                img_transform=dl.get_transform(config.img_dims))
            val_dataset = dl.RawTextFace2TextDataset(
                data_path=config.data_path_val,
                img_dir=config.val_images_dir,  # unnecessary
                img_transform=dl.get_transform(config.img_dims))
            from networks.TextEncoder import PretrainedEncoder
            # create a new session object for the pretrained encoder:
            text_encoder = PretrainedEncoder(
                model_file=config.pretrained_encoder_file,
                embedding_file=config.pretrained_embedding_file,
                device=device)
            encoder_optim = None
            print("Pickling dataset, val_dataset and text_encoder....")
            with open(f'dataset_{config.tensorboard_comment}.pickle',
                      'wb') as handle:
                pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(f'val_dataset_{config.tensorboard_comment}.pickle',
                      'wb') as handle:
                pickle.dump(val_dataset,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(f'text_encoder_{config.tensorboard_comment}.pickle',
                      'wb') as handle:
                pickle.dump(text_encoder,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        else:
            print("Loading dataset, val_dataset and text_encoder from file...")
            with open(f'val_dataset_{config.tensorboard_comment}.pickle',
                      'rb') as handle:
                val_dataset = pickle.load(handle)
            with open(f'dataset_{config.tensorboard_comment}.pickle',
                      'rb') as handle:
                dataset = pickle.load(handle)
            from networks.TextEncoder import PretrainedEncoder
            with open(f'text_encoder_{config.tensorboard_comment}.pickle',
                      'rb') as handle:
                text_encoder = pickle.load(handle)
            encoder_optim = None
    else:
        print("Using Face2TextDataset dataloader...")
        dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                      img_dir=config.images_dir,
                                      img_transform=dl.get_transform(
                                          config.img_dims),
                                      captions_len=config.captions_length)
        text_encoder = Encoder(embedding_size=config.embedding_size,
                               vocab_size=dataset.vocab_size,
                               hidden_size=config.hidden_size,
                               num_layers=config.num_layers,
                               device=device)
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.beta_1, config.beta_2),
                                      eps=config.eps)

    # create the networks

    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             use_eql=config.use_eql,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))
    print("Create cprogan...")
    c_pro_gan = ConditionalProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        compressed_latent_size=config.compressed_latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    #print("Generator Config:")
    print(c_pro_gan.gen)

    #print("\nDiscriminator Config:")
    #print(c_pro_gan.dis)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Create optimizer...")
    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        validation_dataset=val_dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        comment=config.tensorboard_comment,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
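
The examples on this page only consume args; the parser that produces it sits outside the excerpt. Reconstructed purely from the attributes main() reads (args.config, args.start_depth, and the optional checkpoint paths), a minimal sketch of such a parser could look like the following; defaults and help strings are illustrative, not the originals:

import argparse

def parse_arguments():
    # minimal sketch reconstructed from the attributes main() reads;
    # defaults and help strings are illustrative, not the originals
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="path to the YAML configuration file")
    parser.add_argument("--start_depth", type=int, default=0,
                        help="depth at which to start (or resume) training")
    parser.add_argument("--encoder_file", type=str, default=None,
                        help="optional text-encoder checkpoint to load")
    parser.add_argument("--ca_file", type=str, default=None,
                        help="optional ConditionAugmentor checkpoint to load")
    parser.add_argument("--generator_file", type=str, default=None,
                        help="optional generator checkpoint to load")
    parser.add_argument("--discriminator_file", type=str, default=None,
                        help="optional discriminator checkpoint to load")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_arguments())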
Example #2
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.PRO_GAN import ConditionalProGAN

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    if config.use_pretrained_encoder:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims))
        from networks.TextEncoder import PretrainedEncoder
        # create a new TF1-style Session (CPU-only) for the pretrained encoder:
        sess_config = tf.ConfigProto(device_count={"GPU": 0})
        session = tf.Session(config=sess_config)
        text_encoder = PretrainedEncoder(
            session=session,
            module_dir=config.pretrained_encoder_dir,
            download=config.download_pretrained_encoder)
        encoder_optim = None
    else:
        dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                      img_dir=config.images_dir,
                                      img_transform=dl.get_transform(
                                          config.img_dims),
                                      captions_len=config.captions_length)
        text_encoder = Encoder(embedding_size=config.embedding_size,
                               vocab_size=dataset.vocab_size,
                               hidden_size=config.hidden_size,
                               num_layers=config.num_layers,
                               device=device)
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.beta_1, config.beta_2),
                                      eps=config.eps)

    # create the networks

    if args.encoder_file is not None:
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             use_eql=config.use_eql,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))

    c_pro_gan = ConditionalProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        compressed_latent_size=config.compressed_latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
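
Both examples above feed the encoder's hidden_size output into ConditionAugmentor(input_size=config.hidden_size, latent_size=config.ca_out_size), and Example #3 below unpacks its forward pass as (c_not_hats, mus, sigmas). The module itself is not part of the excerpt; the following is a minimal sketch of a StackGAN-style conditioning augmentation matching that interface, offered as a hypothetical stand-in that omits the real class's use_eql and device options:

import torch as th
import torch.nn as nn

class ConditionAugmentorSketch(nn.Module):
    """Hypothetical stand-in: maps a text embedding to a Gaussian and samples
    from it with the reparameterization trick (StackGAN-style)."""

    def __init__(self, input_size, latent_size):
        super().__init__()
        # one linear layer predicts both the means and the (raw) std-devs
        self.transform = nn.Linear(input_size, 2 * latent_size)
        self.latent_size = latent_size

    def forward(self, embedding):
        mid = self.transform(embedding)
        mus, sigmas = mid[:, :self.latent_size], mid[:, self.latent_size:]
        # reparameterization trick: c_hat = mu + sigma * eps, eps ~ N(0, I)
        epsilon = th.randn_like(mus)
        c_not_hats = mus + sigmas * epsilon
        return c_not_hats, mus, sigmas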
Example #3
def homepage_result():
    caption = request.form["des"]
    current_depth = 4

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.C_PRO_GAN import ProGAN

    # define the device for the training script
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    ############################################################################
    # load my generator.

    def get_config(conf_file):
        """
        parse and load the provided configuration
        :param conf_file: configuration file
        :return: conf => parsed configuration
        """
        from easydict import EasyDict as edict

        with open(conf_file, "r") as file_descriptor:
            # safe_load avoids yaml.load's arbitrary-object construction
            data = yaml.safe_load(file_descriptor)

        # convert the data into an easyDictionary
        return edict(data)

    config = get_config("configs\\11.conf")

    c_pro_gan = ProGAN(embedding_size=config.hidden_size,
                       depth=config.depth,
                       latent_size=config.latent_size,
                       learning_rate=config.learning_rate,
                       beta_1=config.beta_1,
                       beta_2=config.beta_2,
                       eps=config.eps,
                       drift=config.drift,
                       n_critic=config.n_critic,
                       device=device)

    c_pro_gan.gen.load_state_dict(
        th.load("training_runs\\11\\saved_models\\GAN_GEN_3_20.pth"))

    ###################################################################################
    # load my embedding and conditional augmentor

    dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                  img_dir=config.images_dir,
                                  img_transform=dl.get_transform(
                                      config.img_dims),
                                  captions_len=config.captions_length)

    text_encoder = Encoder(embedding_size=config.embedding_size,
                           vocab_size=dataset.vocab_size,
                           hidden_size=config.hidden_size,
                           num_layers=config.num_layers,
                           device=device)
    text_encoder.load_state_dict(
        th.load("training_runs\\11\\saved_models\\Encoder_3_20.pth"))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             device=device)
    condition_augmenter.load_state_dict(
        th.load(
            "training_runs\\11\\saved_models\\Condition_Augmentor_3_20.pth"))

    ###################################################################################
    # encode the caption (read from the web form above) into a token-id sequence
    seq = []
    for word in caption.split():
        # NOTE: a word missing from the training vocabulary raises a KeyError here
        seq.append(dataset.rev_vocab[word])
    # zero-pad the sequence up to the fixed caption length of 100
    for i in range(len(seq), 100):
        seq.append(0)

    seq = th.LongTensor(seq).to(device)
    print('\nInput : ', caption)

    # replicate the caption 16 times to generate a batch of samples
    list_seq = th.stack([seq for _ in range(16)]).to(device)

    embeddings = text_encoder(list_seq)

    c_not_hats, mus, sigmas = condition_augmenter(embeddings)

    z = th.randn(list_seq.shape[0],
                 c_pro_gan.latent_size - c_not_hats.shape[-1]).to(device)

    gan_input = th.cat((c_not_hats, z), dim=-1)

    # fade-in blending factor for the generator at this depth
    alpha = 0.007352941176470588

    samples = c_pro_gan.gen(gan_input, current_depth, alpha)

    from torchvision.utils import save_image
    from torch.nn.functional import interpolate

    img_file = "static\\" + caption + '.png'
    # shift the samples from [-1, 1] back to the [0, 1] range
    samples = (samples / 2) + 0.5
    # upscale to the full output resolution when sampling below the final depth
    # (th.nn.functional.upsample is deprecated in favor of interpolate)
    scale_factor = int(np.power(2, c_pro_gan.depth - current_depth - 1))
    if scale_factor > 1:
        samples = interpolate(samples, scale_factor=scale_factor)

    # save the image grid (16 samples, 4 per row) to disk as <caption>.png
    save_image(samples, img_file, nrow=4)

    ###################################################################################
    # #output the image.

    result = "\\static\\" + caption + ".png"
    return render_template("main.html",
                           result_img=result,
                           result_caption=caption)
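
homepage_result reads its caption from request.form and finishes with render_template, so it is evidently a Flask view function, but the snippet omits the app object and the route registration. A minimal sketch of the assumed wiring follows; the app object, URL rule, and HTTP method are guesses, not taken from the source:

from flask import Flask

app = Flask(__name__)

# register the view defined in Example #3 under an assumed URL rule;
# the form on main.html is assumed to POST a "des" field to it
app.add_url_rule("/", endpoint="home",
                 view_func=homepage_result, methods=["POST"])

if __name__ == "__main__":
    app.run(debug=True)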
Example #4
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.C_PRO_GAN import ProGAN

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                  img_dir=config.images_dir,
                                  img_transform=dl.get_transform(
                                      config.img_dims),
                                  captions_len=config.captions_length)

    # create the networks
    text_encoder = Encoder(embedding_size=config.embedding_size,
                           vocab_size=dataset.vocab_size,
                           hidden_size=config.hidden_size,
                           num_layers=config.num_layers,
                           device=device)

    if args.encoder_file is not None:
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))

    c_pro_gan = ProGAN(embedding_size=config.hidden_size,
                       depth=config.depth,
                       latent_size=config.latent_size,
                       learning_rate=config.learning_rate,
                       beta_1=config.beta_1,
                       beta_2=config.beta_2,
                       eps=config.eps,
                       drift=config.drift,
                       n_critic=config.n_critic,
                       device=device)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # create the optimizers for Encoder and Condition Augmenter separately
    encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                  lr=config.learning_rate,
                                  betas=(config.beta_1, config.beta_2),
                                  eps=config.eps)

    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
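
All of these entry points pull their hyper-parameters from get_config, which (as the nested definition in Example #3 shows) parses a YAML file into an easydict.EasyDict. For orientation only, here is a hand-built stand-in with hypothetical values, limited to a subset of the keys Example #4 reads:

from easydict import EasyDict as edict

# hypothetical values for illustration; the real settings live in the
# .conf YAML files (e.g. "configs\\11.conf") and include more keys
config = edict({
    "processed_text_file": "processed_annotations/processed_text.pickle",
    "images_dir": "data/images",
    "img_dims": (64, 64),
    "captions_length": 100,
    "embedding_size": 128,
    "hidden_size": 256,
    "num_layers": 3,
    "ca_out_size": 178,
    "depth": 5,
    "latent_size": 256,
    "learning_rate": 0.001,
    "beta_1": 0.0,
    "beta_2": 0.99,
    "eps": 1e-8,
    "drift": 0.001,
    "n_critic": 1,
    # per-depth schedules for progressive growing
    "epochs": [20, 20, 20, 20, 20],
    "batch_sizes": [128, 64, 32, 16, 8],
    "fade_in_percentage": [50, 50, 50, 50, 50],
})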
Example #5
c_pro_gan = ProGAN(embedding_size=config.hidden_size,
                   depth=config.depth,
                   latent_size=config.latent_size,
                   learning_rate=config.learning_rate,
                   beta_1=config.beta_1,
                   beta_2=config.beta_2,
                   eps=config.eps,
                   drift=config.drift,
                   n_critic=config.n_critic,
                   device=device)

c_pro_gan.gen.load_state_dict(
    th.load("training_runs\\11\\saved_models\\GAN_GEN_4.pth"))

###################################################################################
# load my embedding and conditional augmentor

dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                              img_dir=config.images_dir,
                              img_transform=dl.get_transform(config.img_dims),
                              captions_len=config.captions_length)

text_encoder = Encoder(embedding_size=config.embedding_size,
                       vocab_size=dataset.vocab_size,
                       hidden_size=config.hidden_size,
                       num_layers=config.num_layers,
                       device=device)
text_encoder.load_state_dict(
    th.load("training_runs\\11\\saved_models\\Encoder_3_20.pth"))

condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                         latent_size=config.ca_out_size,
                                         device=device)
condition_augmenter.load_state_dict(
    th.load("training_runs\\11\\saved_models\\Condition_Augmentor_3_20.pth"))
Example #6
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from MSG_GAN.GAN import MSG_GAN
    from MSG_GAN import Losses as lses

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    if config.use_pretrained_encoder:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims)
        )
        from networks.TextEncoder import PretrainedEncoder
        # create a new session object for the pretrained encoder:
        text_encoder = PretrainedEncoder(
            model_file=config.pretrained_encoder_file,
            embedding_file=config.pretrained_embedding_file,
            device=device
        )
        encoder_optim = None
    else:
        dataset = dl.Face2TextDataset(
            pro_pick_file=config.processed_text_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims),
            captions_len=config.captions_length
        )
        text_encoder = Encoder(
            embedding_size=config.embedding_size,
            vocab_size=dataset.vocab_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            device=device
        )
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.adam_beta1, config.adam_beta2),
                                      eps=config.eps)
    msg_gan = MSG_GAN(
        depth=config.depth,
        latent_size=config.latent_size,
        use_eql=config.use_eql,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    gen_optim = th.optim.Adam(msg_gan.gen.parameters(), lr=config.g_lr,
                              betas=(config.adam_beta1, config.adam_beta2))

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), lr=config.d_lr,
                              betas=(config.adam_beta1, config.adam_beta2))

    # relativistic average hinge loss; instantiated with the discriminator below
    loss = lses.RelativisticAverageHingeGAN

    # create the networks

    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(
        input_size=config.hidden_size,
        latent_size=config.ca_out_size,
        use_eql=config.use_eql,
        device=device
    )

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.adam_beta1, config.adam_beta2),
                             eps=config.eps)

    print("Generator Config:")
    print(msg_gan.gen)

    print("\nDiscriminator Config:")
    print(msg_gan.dis)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        msg_gan=msg_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        gen_optim=gen_optim,
        dis_optim=dis_optim,
        loss_fn=loss(msg_gan.dis),
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
    )
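
Example #6 replaces the built-in ProGAN loss with lses.RelativisticAverageHingeGAN and constructs it with the discriminator (loss_fn=loss(msg_gan.dis)). Assuming the standard relativistic average hinge formulation (Jolicoeur-Martineau, 2018), a minimal hypothetical stand-in for such a loss class could look like this:

import torch as th
import torch.nn.functional as F

class RelativisticAverageHingeSketch:
    """Hypothetical stand-in assuming the standard RaHinge formulation;
    real_samps/fake_samps are whatever the discriminator consumes
    (lists of multi-scale images in the MSG-GAN case)."""

    def __init__(self, dis):
        self.dis = dis

    def dis_loss(self, real_samps, fake_samps):
        r_preds, f_preds = self.dis(real_samps), self.dis(fake_samps)
        # real scores should exceed the average fake score by a margin,
        # and fake scores should fall below the average real score
        real_loss = th.mean(F.relu(1 - (r_preds - th.mean(f_preds))))
        fake_loss = th.mean(F.relu(1 + (f_preds - th.mean(r_preds))))
        return real_loss + fake_loss

    def gen_loss(self, real_samps, fake_samps):
        r_preds, f_preds = self.dis(real_samps), self.dis(fake_samps)
        # the generator tries to invert the relativistic ordering
        real_loss = th.mean(F.relu(1 + (r_preds - th.mean(f_preds))))
        fake_loss = th.mean(F.relu(1 - (f_preds - th.mean(r_preds))))
        return real_loss + fake_loss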