Example #1
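All three examples are main functions lifted from larger training scripts, so they rely on module-level names (dl, th, tf, device, get_config, train_networks) that are not shown. A minimal preamble along these lines is assumed; the exact module paths are guesses inferred from the imports visible inside the functions:

# presumed module-level preamble (assumption -- not part of the original snippets)
import os
import pickle

import tensorflow as tf  # Example #1 only; uses the TF1-style Session API
import torch as th

import data_processing.DataLoader as dl  # hypothetical path for the dl module
from utils.config import get_config      # hypothetical path for get_config

# device on which the PyTorch networks are placed
device = th.device("cuda" if th.cuda.is_available() else "cpu")
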
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.PRO_GAN import ConditionalProGAN

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    if config.use_pretrained_encoder:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims))
        from networks.TextEncoder import PretrainedEncoder
        # create a CPU-only TF1-style session (GPU disabled) for the pretrained encoder:
        sess_config = tf.ConfigProto(device_count={"GPU": 0})
        session = tf.Session(config=sess_config)
        text_encoder = PretrainedEncoder(
            session=session,
            module_dir=config.pretrained_encoder_dir,
            download=config.download_pretrained_encoder)
        encoder_optim = None
    else:
        dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                      img_dir=config.images_dir,
                                      img_transform=dl.get_transform(
                                          config.img_dims),
                                      captions_len=config.captions_length)
        text_encoder = Encoder(embedding_size=config.embedding_size,
                               vocab_size=dataset.vocab_size,
                               hidden_size=config.hidden_size,
                               num_layers=config.num_layers,
                               device=device)
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.beta_1, config.beta_2),
                                      eps=config.eps)

    # create the networks

    if args.encoder_file is not None:
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             use_eql=config.use_eql,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))

    c_pro_gan = ConditionalProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        compressed_latent_size=config.compressed_latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
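
For reference, args must carry config, encoder_file, ca_file, generator_file, discriminator_file and start_depth. A minimal parser that would satisfy this main (a sketch; everything beyond those attribute names is an assumption) could look like:

import argparse

def parse_arguments():
    # build the attributes that main() reads from args
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="path to the configuration file")
    parser.add_argument("--encoder_file", type=str, default=None,
                        help="optional saved text-encoder weights")
    parser.add_argument("--ca_file", type=str, default=None,
                        help="optional saved ConditionAugmentor weights")
    parser.add_argument("--generator_file", type=str, default=None,
                        help="optional saved generator weights")
    parser.add_argument("--discriminator_file", type=str, default=None,
                        help="optional saved discriminator weights")
    parser.add_argument("--start_depth", type=int, default=0,
                        help="depth at which progressive training (re)starts")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_arguments())
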
Example #2
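This variant extends Example #1 with a validation dataset, pickles the datasets and the pretrained text encoder so that later runs can skip the expensive rebuild, and threads config.tensorboard_comment through as both the cache-file key and the TensorBoard run comment.
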
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    from networks.PRO_GAN import ConditionalProGAN

    #print(args.config)
    config = get_config(args.config)
    #print("Current Configuration:", config)

    print("Create dataset...")
    # create the dataset for training
    if config.use_pretrained_encoder:
        print("Using PretrainedEncoder...")
        if not os.path.exists(
                f"text_encoder_{config.tensorboard_comment}.pickle"):

            print("Creating new vocab and dataset pickle files ...")
            dataset = dl.RawTextFace2TextDataset(
                data_path=config.data_path,
                img_dir=config.images_dir,
                img_transform=dl.get_transform(config.img_dims))
            val_dataset = dl.RawTextFace2TextDataset(
                data_path=config.data_path_val,
                img_dir=config.val_images_dir,  # unnecessary
                img_transform=dl.get_transform(config.img_dims))
            from networks.TextEncoder import PretrainedEncoder
            # create a new session object for the pretrained encoder:
            text_encoder = PretrainedEncoder(
                model_file=config.pretrained_encoder_file,
                embedding_file=config.pretrained_embedding_file,
                device=device)
            encoder_optim = None
            print("Pickling dataset, val_dataset and text_encoder....")
            with open(f'dataset_{config.tensorboard_comment}.pickle',
                      'wb') as handle:
                pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(f'val_dataset_{config.tensorboard_comment}.pickle',
                      'wb') as handle:
                pickle.dump(val_dataset,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(f'text_encoder_{config.tensorboard_comment}.pickle',
                      'wb') as handle:
                pickle.dump(text_encoder,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        else:
            print("Loading dataset, val_dataset and text_encoder from file...")
            with open(f'val_dataset_{config.tensorboard_comment}.pickle',
                      'rb') as handle:
                val_dataset = pickle.load(handle)
            with open(f'dataset_{config.tensorboard_comment}.pickle',
                      'rb') as handle:
                dataset = pickle.load(handle)
            from networks.TextEncoder import PretrainedEncoder
            with open(f'text_encoder_{config.tensorboard_comment}.pickle',
                      'rb') as handle:
                text_encoder = pickle.load(handle)
            encoder_optim = None
    else:
        print("Using Face2TextDataset dataloader...")
        dataset = dl.Face2TextDataset(pro_pick_file=config.processed_text_file,
                                      img_dir=config.images_dir,
                                      img_transform=dl.get_transform(
                                          config.img_dims),
                                      captions_len=config.captions_length)
        text_encoder = Encoder(embedding_size=config.embedding_size,
                               vocab_size=dataset.vocab_size,
                               hidden_size=config.hidden_size,
                               num_layers=config.num_layers,
                               device=device)
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.beta_1, config.beta_2),
                                      eps=config.eps)

    # create the networks

    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(input_size=config.hidden_size,
                                             latent_size=config.ca_out_size,
                                             use_eql=config.use_eql,
                                             device=device)

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))
    print("Create cprogan...")
    c_pro_gan = ConditionalProGAN(
        embedding_size=config.hidden_size,
        depth=config.depth,
        latent_size=config.latent_size,
        compressed_latent_size=config.compressed_latent_size,
        learning_rate=config.learning_rate,
        beta_1=config.beta_1,
        beta_2=config.beta_2,
        eps=config.eps,
        drift=config.drift,
        n_critic=config.n_critic,
        use_eql=config.use_eql,
        loss=config.loss_function,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    #print("Generator Config:")
    print(c_pro_gan.gen)

    #print("\nDiscriminator Config:")
    #print(c_pro_gan.dis)

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        c_pro_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        c_pro_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Create optimizer...")
    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.beta_1, config.beta_2),
                             eps=config.eps)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        c_pro_gan=c_pro_gan,
        dataset=dataset,
        validation_dataset=val_dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
        comment=config.tensorboard_comment,
        use_matching_aware_dis=config.use_matching_aware_discriminator)
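
The caching above keys every pickle file on config.tensorboard_comment but only checks for the text_encoder pickle, so a run interrupted between the three dumps leaves an inconsistent cache. A small helper capturing the same load-or-build pattern per file (hypothetical; not part of the original script):

import os
import pickle

def load_or_build(path, build_fn):
    # return the object cached at `path`, building and pickling it on a miss
    if os.path.exists(path):
        with open(path, "rb") as handle:
            return pickle.load(handle)
    obj = build_fn()
    with open(path, "wb") as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return obj

# usage sketch:
# dataset = load_or_build(f"dataset_{config.tensorboard_comment}.pickle",
#                         lambda: dl.RawTextFace2TextDataset(...))
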
Example #3
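The third variant swaps ConditionalProGAN for an MSG_GAN: the generator and discriminator optimizers and the relativistic average hinge loss are created explicitly here and handed to train_networks, rather than being managed inside the GAN wrapper.
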
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments
    :return: None
    """

    from networks.TextEncoder import Encoder
    from networks.ConditionAugmentation import ConditionAugmentor
    #from pro_gan_pytorch.PRO_GAN import ConditionalProGAN
    from MSG_GAN.GAN import MSG_GAN
    from MSG_GAN import Losses as lses

    print(args.config)
    config = get_config(args.config)
    print("Current Configuration:", config)

    # create the dataset for training
    if config.use_pretrained_encoder:
        dataset = dl.RawTextFace2TextDataset(
            annots_file=config.annotations_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims)
        )
        from networks.TextEncoder import PretrainedEncoder
        # create a new session object for the pretrained encoder:
        text_encoder = PretrainedEncoder(
            model_file=config.pretrained_encoder_file,
            embedding_file=config.pretrained_embedding_file,
            device=device
        )
        encoder_optim = None
    else:
        dataset = dl.Face2TextDataset(
            pro_pick_file=config.processed_text_file,
            img_dir=config.images_dir,
            img_transform=dl.get_transform(config.img_dims),
            captions_len=config.captions_length
        )
        text_encoder = Encoder(
            embedding_size=config.embedding_size,
            vocab_size=dataset.vocab_size,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            device=device
        )
        encoder_optim = th.optim.Adam(text_encoder.parameters(),
                                      lr=config.learning_rate,
                                      betas=(config.adam_beta1, config.adam_beta2),
                                      eps=config.eps)
    msg_gan = MSG_GAN(
        depth=config.depth,
        latent_size=config.latent_size,
        use_eql=config.use_eql,
        use_ema=config.use_ema,
        ema_decay=config.ema_decay,
        device=device)

    genoptim = th.optim.Adam(msg_gan.gen.parameters(), lr=config.g_lr,
                             betas=(config.adam_beta1, config.adam_beta2))

    disoptim = th.optim.Adam(msg_gan.dis.parameters(), lr=config.d_lr,
                             betas=(config.adam_beta1, config.adam_beta2))

    loss = lses.RelativisticAverageHingeGAN
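    # note: loss holds the class itself; it is instantiated with the
    # discriminator, loss(msg_gan.dis), in the train_networks call below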

    # create the networks

    if args.encoder_file is not None:
        # Note this should not be used with the pretrained encoder file
        print("Loading encoder from:", args.encoder_file)
        text_encoder.load_state_dict(th.load(args.encoder_file))

    condition_augmenter = ConditionAugmentor(
        input_size=config.hidden_size,
        latent_size=config.ca_out_size,
        use_eql=config.use_eql,
        device=device
    )

    if args.ca_file is not None:
        print("Loading conditioning augmenter from:", args.ca_file)
        condition_augmenter.load_state_dict(th.load(args.ca_file))

    if args.generator_file is not None:
        print("Loading generator from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    if args.discriminator_file is not None:
        print("Loading discriminator from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(condition_augmenter.parameters(),
                             lr=config.learning_rate,
                             betas=(config.adam_beta1, config.adam_beta2),
                             eps=config.eps)

    print("Generator Config:")
    print(msg_gan.gen)

    print("\nDiscriminator Config:")
    print(msg_gan.dis)

    # train all the networks
    train_networks(
        encoder=text_encoder,
        ca=condition_augmenter,
        msg_gan=msg_gan,
        dataset=dataset,
        encoder_optim=encoder_optim,
        ca_optim=ca_optim,
        gen_optim=genoptim,
        dis_optim=disoptim,
        loss_fn=loss(msg_gan.dis),
        epochs=config.epochs,
        fade_in_percentage=config.fade_in_percentage,
        start_depth=args.start_depth,
        batch_sizes=config.batch_sizes,
        num_workers=config.num_workers,
        feedback_factor=config.feedback_factor,
        log_dir=config.log_dir,
        sample_dir=config.sample_dir,
        checkpoint_factor=config.checkpoint_factor,
        save_dir=config.save_dir,
    )
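
The Losses.RelativisticAverageHingeGAN class used above is not shown. For orientation, the relativistic average hinge loss is conventionally computed as below; this is a generic sketch of the published formulation, not the project's exact implementation:

import torch as th

def rahinge_dis_loss(real_scores, fake_scores):
    # hinge applied to scores made relative to the mean score of the other class
    rel_real = real_scores - fake_scores.mean()
    rel_fake = fake_scores - real_scores.mean()
    return th.relu(1 - rel_real).mean() + th.relu(1 + rel_fake).mean()

def rahinge_gen_loss(real_scores, fake_scores):
    # the generator objective swaps the roles of real and fake
    rel_real = real_scores - fake_scores.mean()
    rel_fake = fake_scores - real_scores.mean()
    return th.relu(1 + rel_real).mean() + th.relu(1 - rel_fake).mean()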