Example #1
0
def main(args):
    """
    Main function for the script

    Builds the dataset, the attention-GAN generator/discriminator pair
    from their config files, and trains with the relativistic average
    hinge loss.

    :param args: parsed command line arguments
    :return: None
    """
    from attn_gan_pytorch.Utils import get_layer
    from attn_gan_pytorch.ConfigManagement import get_config
    from attn_gan_pytorch.Networks import Generator, Discriminator, GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader
    from attn_gan_pytorch.Losses import RelativisticAverageHingeGAN

    # dataset + loader (images resized to 64x64)
    dataset = FlatDirectoryImageDataset(
        args.images_dir, transform=get_transform((64, 64)))
    data = get_data_loader(dataset, args.batch_size, args.num_workers)

    def layers_from(config_file):
        # turn a config's architecture description into concrete layers
        return [get_layer(entry) for entry in get_config(config_file).architecture]

    # create generator object:
    generator = Generator(layers_from(args.generator_config), args.latent_size)
    print("Generator Configuration: ")
    print(generator)

    # create discriminator object:
    discriminator = Discriminator(layers_from(args.discriminator_config))
    print("Discriminator Configuration: ")
    print(discriminator)

    # create a gan from these
    sagan = GAN(generator, discriminator, device=device)

    def trainable_params(network):
        # only parameters that require gradients are optimized
        return filter(lambda p: p.requires_grad, network.parameters())

    # optimizers for both networks (beta1=0, beta2=0.9)
    gen_optim = th.optim.Adam(trainable_params(generator),
                              args.g_lr, [0, 0.9])
    dis_optim = th.optim.Adam(trainable_params(discriminator),
                              args.d_lr, [0, 0.9])

    # train the GAN
    sagan.train(
        data,
        gen_optim,
        dis_optim,
        loss_fn=RelativisticAverageHingeGAN(device, discriminator),
        num_epochs=args.num_epochs,
        checkpoint_factor=args.checkpoint_factor,
        data_percentage=args.data_percentage,
        feedback_factor=31,
        num_samples=64,
        save_dir="models/relativistic/",
        sample_dir="samples/4/",
        log_dir="models/relativistic"
    )
Example #2
0
def main(args):
    """
    Main function for the script

    Builds a SMSG-GAN of the requested depth, optionally resumes from
    saved generator/discriminator weights, selects the requested loss
    function and trains the model.

    :param args: parsed command line arguments
    :return: None
    :raises ValueError: if ``args.loss_function`` is not a recognised name
    """
    from SMSG_GAN.SMSG_GAN import SMSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader
    from SMSG_GAN.Losses import HingeGAN, RelativisticAverageHingeGAN, \
        StandardGAN, LSGAN

    # create a data source; images are resized to 2 ** (depth + 1) per side,
    # the highest resolution this GAN of depth `args.depth` produces
    celeba_dataset = FlatDirectoryImageDataset(
        args.images_dir,
        transform=get_transform(
            (int(np.power(2, args.depth + 1)), int(np.power(2,
                                                            args.depth + 1)))))

    data = get_data_loader(celeba_dataset, args.batch_size, args.num_workers)

    # create a gan from these
    smsg_gan = SMSG_GAN(depth=args.depth,
                        latent_size=args.latent_size,
                        device=device)

    if args.generator_file is not None:
        # load the weights into generator
        smsg_gan.gen.load_state_dict(th.load(args.generator_file))

    print("Generator Configuration: ")
    print(smsg_gan.gen)

    if args.discriminator_file is not None:
        # load the weights into discriminator
        smsg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Discriminator Configuration: ")
    print(smsg_gan.dis)

    # create optimizers (beta1=0, beta2=0.99):
    gen_optim = th.optim.Adam(smsg_gan.gen.parameters(), args.g_lr, [0, 0.99])

    dis_optim = th.optim.Adam(smsg_gan.dis.parameters(), args.d_lr, [0, 0.99])

    # single dispatch table instead of an if/elif chain; the error now
    # reports which name was actually requested
    loss_classes = {
        "hinge": HingeGAN,
        "relativistic-hinge": RelativisticAverageHingeGAN,
        "standard-gan": StandardGAN,
        "lsgan": LSGAN,
    }
    loss_name = args.loss_function.lower()
    if loss_name not in loss_classes:
        # ValueError is a subclass of Exception, so existing callers
        # catching Exception still work
        raise ValueError(
            "Unknown loss function requested: %r" % loss_name)
    loss = loss_classes[loss_name]

    # train the GAN
    smsg_gan.train(data,
                   gen_optim,
                   dis_optim,
                   loss_fn=loss(device, smsg_gan.dis),
                   num_epochs=args.num_epochs,
                   checkpoint_factor=args.checkpoint_factor,
                   data_percentage=args.data_percentage,
                   feedback_factor=args.feedback_factor,
                   num_samples=args.num_samples,
                   sample_dir=args.sample_dir,
                   save_dir=args.model_dir,
                   log_dir=args.model_dir,
                   start=args.start)
Example #3
0
def main(args):
    """
    Main function for the script

    Builds an MSG-GAN of the requested depth, optionally resumes
    generator / shadow-generator / discriminator weights and optimizer
    states, selects the requested loss function and trains the model.

    NOTE(review): relies on module-level ``np``, ``th`` and ``device``
    imported/defined elsewhere in this file — not visible in this chunk.

    :param args: parsed command line arguments
    :return: None
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset
    from MSG_GAN import Losses as lses

    # create a data source:
    # two directory layouts are supported: a flat folder of images, or
    # one sub-folder per class (labels ignored) — chosen by CLI flag
    data_source = FlatDirectoryImageDataset if not args.folder_distributed \
        else FoldersDistributedDataset

    # images are resized to 2 ** (depth + 1) pixels per side, the highest
    # resolution an MSG-GAN of this depth generates
    dataset = data_source(args.images_dir,
                          transform=get_transform(
                              (int(np.power(2, args.depth + 1)),
                               int(np.power(2, args.depth + 1))),
                              flip_horizontal=args.flip_augment))

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    # optionally resume network weights from earlier checkpoints
    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into generator
        # (the "shadow" generator is the EMA copy used with use_ema)
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizer for generator:
    gen_optim = th.optim.Adam(msg_gan.gen.parameters(), args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    # optimizer states are restored after the optimizers are built so a
    # resumed run continues with the same Adam moments
    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    # map the CLI loss name onto a loss class from MSG_GAN.Losses
    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise Exception("Unknown loss function requested")

    # train the GAN
    msg_gan.train(data,
                  gen_optim,
                  dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.model_dir,
                  start=args.start)
Example #4
0
def main(args):
    """
    Main function for the script

    Builds a VDB (variational discriminator bottleneck) GAN, optionally
    resumes network weights and optimizer states, selects the requested
    loss function and trains the model.

    :param args: parsed command line arguments
    :return: None
    :raises ValueError: if ``args.loss_function`` is not a recognised name
    """
    from vdb.Gan import GAN
    from vdb.Gan_networks import Generator, Discriminator
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset
    from vdb.Losses import WGAN_GP, HingeGAN, RelativisticAverageHingeGAN, \
        StandardGAN, LSGAN, StandardGANWithSigmoid

    # create a data source (flat folder vs. one folder per class):
    if args.folder_distributed_dataset:
        data_extractor = FoldersDistributedDataset
    else:
        data_extractor = FlatDirectoryImageDataset

    dataset = data_extractor(args.images_dir,
                             get_transform((args.size, args.size)))

    print("Total number of images in the dataset:", len(dataset))

    data = get_data_loader(dataset, args.batch_size, args.num_workers)

    # create the Generator and Discriminator objects:
    generator = Generator(args.latent_size, args.size, args.final_channels,
                          args.max_channels).to(device)

    discriminator = Discriminator(args.size, args.final_channels,
                                  args.max_channels).to(device)

    # create a gan from these
    vdb_gan = GAN(gen=generator, dis=discriminator, device=device)

    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator weights from:", args.generator_file)
        vdb_gan.gen.load_state_dict(th.load(args.generator_file))

    print("Generator Configuration: ")
    print(vdb_gan.gen)

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator weights from:", args.discriminator_file)
        vdb_gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Discriminator Configuration: ")
    print(vdb_gan.dis)

    # create optimizer for generator (optimizer state is restored right
    # after creation so a resumed run keeps its RMSprop moments):
    gen_optim = th.optim.RMSprop(vdb_gan.gen.parameters(), lr=args.g_lr)

    if args.gen_optim_file is not None:
        print("loading state of the gen optimizer from:", args.gen_optim_file)
        gen_optim.load_state_dict(th.load(args.gen_optim_file))

    dis_optim = th.optim.RMSprop(vdb_gan.dis.parameters(), lr=args.d_lr)

    if args.dis_optim_file is not None:
        print("loading state of the dis optimizer from:", args.dis_optim_file)
        dis_optim.load_state_dict(th.load(args.dis_optim_file))

    # single dispatch table instead of an if/elif chain; the error now
    # reports which name was actually requested
    loss_classes = {
        "hinge": HingeGAN,
        "relativistic-hinge": RelativisticAverageHingeGAN,
        "standard-gan": StandardGAN,
        "standard-gan_with-sigmoid": StandardGANWithSigmoid,
        "wgan-gp": WGAN_GP,
        "lsgan": LSGAN,
    }
    loss_name = args.loss_function.lower()
    if loss_name not in loss_classes:
        # ValueError is a subclass of Exception, so existing callers
        # catching Exception still work
        raise ValueError(
            "Unknown loss function requested: %r" % loss_name)
    loss = loss_classes[loss_name]

    # train the GAN
    vdb_gan.train(data=data,
                  gen_optim=gen_optim,
                  dis_optim=dis_optim,
                  loss_fn=loss(vdb_gan.dis),
                  init_beta=args.init_beta,
                  i_c=args.i_c,
                  latent_distrib=args.latent_distrib,
                  start=args.start,
                  num_epochs=args.num_epochs,
                  feedback_factor=args.feedback_factor,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  num_samples=args.num_samples,
                  log_dir=args.model_dir,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir)
Example #5
0
def main(args):
    """
    Main function for the script

    Builds an MSG-GAN over either a flat image directory or the CIFAR-10
    torchvision dataset, optionally resumes weights/optimizer states,
    picks the loss function, derives timestamped output directories and
    trains the model.

    NOTE(review): relies on module-level ``np``, ``th``, ``device``,
    ``datetime``, ``dateutil`` and ``CIFAR10`` imported elsewhere in
    this file — not visible in this chunk.

    :param args: parsed command line arguments (``args.model_dir``,
        ``args.sample_dir`` and ``args.log_dir`` are overwritten below)
    :return: None
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset, IgnoreLabels
    from MSG_GAN import Losses as lses

    # create a data source:
    # either a directory of images, or a named torchvision dataset
    if args.pytorch_dataset is None:
        data_source = FlatDirectoryImageDataset if not args.folder_distributed \
            else FoldersDistributedDataset

        # images resized to 2 ** (depth + 1) per side — the top MSG-GAN scale
        dataset = data_source(args.images_dir,
                              transform=get_transform(
                                  (int(np.power(2, args.depth + 1)),
                                   int(np.power(2, args.depth + 1))),
                                  flip_horizontal=args.flip_augment))
    else:
        dataset_name = args.pytorch_dataset.lower()
        if dataset_name == "cifar10":
            # IgnoreLabels strips the (image, label) tuples down to images
            dataset = IgnoreLabels(
                CIFAR10(args.dataset_dir,
                        transform=get_transform(
                            (int(np.power(2, args.depth + 1)),
                             int(np.power(2, args.depth + 1))),
                            flip_horizontal=args.flip_augment),
                        download=True))
        else:
            raise Exception("Unknown dataset requested")

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    # optionally resume network weights from earlier checkpoints
    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(th.load(args.generator_file))

    # print("Generator Configuration: ")
    # print(msg_gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into generator
        # (the "shadow" generator is the EMA copy used with use_ema)
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(th.load(args.discriminator_file))

    # print("Discriminator Configuration: ")
    # print(msg_gan.dis)

    # create optimizer for generator:
    # per-parameter-group learning rates — the style mapping network
    # trains 100x slower than the synthesis layers / RGB converters.
    # The extra 'mult' key is presumably consumed by a custom LR
    # scheduler elsewhere — TODO confirm; th.optim.Adam ignores it.
    gen_params = [{
        'params': msg_gan.gen.style.parameters(),
        'lr': args.g_lr * 0.01,
        'mult': 0.01
    }, {
        'params': msg_gan.gen.layers.parameters(),
        'lr': args.g_lr
    }, {
        'params': msg_gan.gen.rgb_converters.parameters(),
        'lr': args.g_lr
    }]
    gen_optim = th.optim.Adam(gen_params, args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    # optimizer states restored after construction so a resumed run
    # keeps its Adam moments
    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    # map the CLI loss name onto a loss class from MSG_GAN.Losses
    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise Exception("Unknown loss function requested")

    # derive a timestamped output directory name from the data source;
    # the heuristic falls back to 'flowers' for any non-celeb image dir
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    if args.pytorch_dataset is not None:
        dataName = 'cifar'
    elif args.images_dir.find('celeb') != -1:
        dataName = 'celeb'
    else:
        dataName = 'flowers'
    output_dir = 'output/%s_%s_%s' % \
        ('attnmsggan', dataName, timestamp)
    # overwrite the parsed args in place so downstream code sees the
    # derived paths
    args.model_dir = output_dir + '/models'
    args.sample_dir = output_dir + '/images'
    args.log_dir = output_dir + '/logs'

    # train the GAN
    msg_gan.train(data,
                  gen_optim,
                  dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.log_dir,
                  start=args.start)
def main(args):
    """
    Main function for the script

    Builds a text-conditional GAN (condition augmentor + pretrained
    sentence encoder + generator/discriminator), optionally resumes
    weights/optimizer states and trains the model.

    NOTE(review): reads data paths from the SM_CHANNEL_TRAINING
    environment variable — looks like an AWS SageMaker training-channel
    convention; confirm against the deployment setup. Relies on
    module-level ``os``, ``np``, ``th`` and ``device`` imported
    elsewhere in this file.

    :param args: parsed command line arguments
    :return: None
    """
    from GAN.GAN import ConditionalGAN
    from data_processing.DataLoader import get_transform, get_data_loader, \
        RawTextFace2TextDataset
    from GAN import Losses as lses
    from GAN.TextEncoder import PretrainedEncoder
    from GAN.ConditionAugmentation import ConditionAugmentor

    base_dir = os.environ["SM_CHANNEL_TRAINING"]

    # fixed locations of the images, annotations, InferSent encoder
    # checkpoint and fastText embeddings inside the training channel
    images_dir_path = os.path.join(base_dir, "data")
    annotation_file = os.path.join(base_dir, "face2text_v1.0/raw.json")
    encoder_file = os.path.join(base_dir, "infersent2/infersent2.pkl")
    embedding_file = os.path.join(base_dir, "fasttext/crawl-300d-2M.vec")

    # transformation routine:
    # images resized to 2 ** (depth + 1) per side — the top GAN scale
    res = int(np.power(2, args.depth + 1))
    img_transform = get_transform((res, res),
                                  flip_horizontal=args.flip_augment)

    # create a data source:
    dataset = RawTextFace2TextDataset(
        annots_file=annotation_file,
        img_dir=images_dir_path,
        img_transform=img_transform,
    )

    # create a new session object for the pretrained encoder:
    # the encoder is frozen here (no optimizer is created for it)
    text_encoder = PretrainedEncoder(
        model_file=encoder_file,
        embedding_file=embedding_file,
        device=device,
    )
    encoder_optim = None

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    gan = ConditionalGAN(depth=args.depth,
                         latent_size=args.latent_size,
                         ca_hidden_size=args.ca_hidden_size,
                         ca_out_size=args.ca_out_size,
                         loss_fn=args.loss_function,
                         use_eql=args.use_eql,
                         use_ema=args.use_ema,
                         ema_decay=args.ema_decay,
                         device=device)

    # optionally resume each sub-network from an earlier checkpoint
    if args.ca_file is not None:
        print("loading conditioning augmenter from:", args.ca_file)
        gan.ca.load_state_dict(th.load(args.ca_file))

    print("Augmentor Configuration: ")
    print(gan.ca)

    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        gan.gen.load_state_dict(th.load(args.generator_file))

    print("Generator Configuration: ")
    print(gan.gen)

    if args.shadow_generator_file is not None:
        # load the weights into generator
        # (the "shadow" generator is the EMA copy used with use_ema)
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        gan.gen_shadow.load_state_dict(th.load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        gan.dis.load_state_dict(th.load(args.discriminator_file))

    print("Discriminator Configuration: ")
    print(gan.dis)

    # create the optimizer for Condition Augmenter separately
    ca_optim = th.optim.Adam(gan.ca.parameters(),
                             lr=args.a_lr,
                             betas=[args.adam_beta1, args.adam_beta2])

    # create optimizer for generator:
    gen_optim = th.optim.Adam(gan.gen.parameters(), args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    # optimizer states restored after construction so a resumed run
    # keeps its Adam moments
    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(th.load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(th.load(args.discriminator_optim_file))

    # train the GAN (includes optional FID logging configuration)
    gan.train(data,
              gen_optim,
              dis_optim,
              ca_optim,
              text_encoder,
              encoder_optim,
              num_epochs=args.num_epochs,
              checkpoint_factor=args.checkpoint_factor,
              data_percentage=args.data_percentage,
              feedback_factor=args.feedback_factor,
              num_samples=args.num_samples,
              sample_dir=args.sample_dir,
              save_dir=args.model_dir,
              log_dir=args.model_dir,
              start=args.start,
              spoofing_factor=args.spoofing_factor,
              log_fid_values=args.log_fid_values,
              num_fid_images=args.num_fid_images,
              fid_temp_folder=args.fid_temp_folder,
              fid_real_stats=args.fid_real_stats,
              fid_batch_size=args.fid_batch_size)
    # NOTE(review): everything below looks like a fragment pasted in
    # from a different script (an autoencoder / style-transfer setup).
    # It sits inside this function's scope, shadows the outer `args`,
    # and references names never defined in this file as shown
    # (Style_AutoEncoder, torch, optim, read_picture, style_transfer)
    # — confirm its origin before relying on it.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrain', action='store_true', default=False)

    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset

    data_source = FlatDirectoryImageDataset

    # hard-coded flowers dataset path — TODO confirm/parameterize
    images_dir = "../BMSG-GAN/sourcecode/flowers/data/jpg"
    dataset = data_source(images_dir, transform=get_transform((256, 256)))
    loader = get_data_loader(dataset, 6, 4)

    ae_model = Style_AutoEncoder().to(device)

    if args.pretrain:
        print('Weights Loading....')
        ae_model.load_state_dict(torch.load('models/AutoEncoder/ae.pth'))

    optimizer = optim.Adam(ae_model.parameters(), lr=0.003, amsgrad=True)

    style_img = read_picture('style_samples/style_pictures/sky_2.jpg',
                             normalization=True).to(device)
    print(style_img.size())
    transformer = style_transfer(style_img).to(device)
Example #8
0
def main(args):
    """
    Main function for the script

    Builds the fagan generator/discriminator from their config files,
    optionally resumes saved weights, picks the requested loss and
    trains the model.

    :param args: parsed command line arguments
    :return: None
    """
    from attn_gan_pytorch.Utils import get_layer
    from attn_gan_pytorch.ConfigManagement import get_config
    from attn_gan_pytorch.Networks import Generator, Discriminator, GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader
    from attn_gan_pytorch.Losses import HingeGAN, RelativisticAverageHingeGAN

    # dataset + loader (images resized to 64x64)
    dataset = FlatDirectoryImageDataset(
        args.images_dir, transform=get_transform((64, 64)))
    data = get_data_loader(dataset, args.batch_size, args.num_workers)

    def layers_from(config_file):
        # turn a config's architecture description into concrete layers
        return [get_layer(entry) for entry in get_config(config_file).architecture]

    # create generator object (optionally resuming saved weights):
    generator = Generator(layers_from(args.generator_config), args.latent_size)
    if args.generator_file is not None:
        # load the weights into generator
        generator.load_state_dict(th.load(args.generator_file))
    print("Generator Configuration: ")
    print(generator)

    # create discriminator object (optionally resuming saved weights):
    discriminator = Discriminator(layers_from(args.discriminator_config))
    if args.discriminator_file is not None:
        # load the weights into discriminator
        discriminator.load_state_dict(th.load(args.discriminator_file))
    print("Discriminator Configuration: ")
    print(discriminator)

    # create a gan from these
    fagan = GAN(generator, discriminator, device=device)

    def trainable_params(network):
        # only parameters that require gradients are optimized
        return filter(lambda p: p.requires_grad, network.parameters())

    # optimizers for both networks (beta1=0, beta2=0.9)
    gen_optim = th.optim.Adam(trainable_params(generator),
                              args.g_lr, [0, 0.9])
    dis_optim = th.optim.Adam(trainable_params(discriminator),
                              args.d_lr, [0, 0.9])

    # map the CLI loss name onto a loss class
    loss_map = {
        "hinge": HingeGAN,
        "relativistic-hinge": RelativisticAverageHingeGAN,
    }
    loss_name = args.loss_function.lower()
    if loss_name not in loss_map:
        raise Exception("Unknown loss function requested")
    loss = loss_map[loss_name]

    # train the GAN
    fagan.train(
        data,
        gen_optim,
        dis_optim,
        loss_fn=loss(device, discriminator),
        num_epochs=args.num_epochs,
        checkpoint_factor=args.checkpoint_factor,
        data_percentage=args.data_percentage,
        feedback_factor=args.feedback_factor,
        num_samples=64,
        sample_dir=args.sample_dir,
        save_dir=args.model_dir,
        log_dir=args.model_dir,
        start=args.start
    )