コード例 #1
0
ファイル: train.py プロジェクト: manojtld/BMSG-GAN
def main(args):
    """
    Main function for the script.

    Builds the image dataset and data loader, creates the MSG-GAN and its
    Adam optimizers (optionally resuming weights / optimizer state from the
    checkpoint files named in ``args``), selects the loss, then trains.

    :param args: parsed command line arguments
    :return: None
    :raises ValueError: if ``args.loss_function`` is not a known loss name
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset
    from MSG_GAN import Losses as lses

    # create a data source:
    data_source = FlatDirectoryImageDataset if not args.folder_distributed \
        else FoldersDistributedDataset

    # the transform is given a square (side, side) target size,
    # with side = 2 ** (depth + 1)
    image_size = int(np.power(2, args.depth + 1))
    dataset = data_source(args.images_dir,
                          transform=get_transform(
                              (image_size, image_size),
                              flip_horizontal=args.flip_augment))

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    def _load(path):
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # resumed on the device this run uses instead of failing to load.
        # NOTE(review): th.load unpickles the file - only load trusted
        # checkpoints.
        return th.load(path, map_location=device)

    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(_load(args.generator_file))

    if args.shadow_generator_file is not None:
        # load the weights into the shadow (EMA) copy of the generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(_load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(_load(args.discriminator_file))

    # create optimizers for generator and discriminator (shared Adam betas,
    # separate learning rates):
    gen_optim = th.optim.Adam(msg_gan.gen.parameters(), args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(_load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(_load(args.discriminator_optim_file))

    # loss names are matched case-insensitively
    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise ValueError(
            "Unknown loss function requested: %s" % args.loss_function)

    # train the GAN (logs are written to the same directory as the models)
    msg_gan.train(data,
                  gen_optim,
                  dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.model_dir,
                  start=args.start)
コード例 #2
0
ファイル: train.py プロジェクト: tomguluson92/Style-MSG-GAN
def main(args):
    """
    Main function for the script.

    Builds the dataset (an image directory, or CIFAR-10 when
    ``args.pytorch_dataset`` is set) and its data loader, creates the MSG-GAN
    and its Adam optimizers (optionally resuming from the checkpoint files
    named in ``args``), selects the loss, derives a timestamped output
    directory, then trains.

    Side effect: overwrites ``args.model_dir``, ``args.sample_dir`` and
    ``args.log_dir`` with paths under the generated output directory.

    :param args: parsed command line arguments
    :return: None
    :raises ValueError: if ``args.pytorch_dataset`` or ``args.loss_function``
        names an unsupported dataset / loss
    """
    from MSG_GAN.GAN import MSG_GAN
    from data_processing.DataLoader import FlatDirectoryImageDataset, \
        get_transform, get_data_loader, FoldersDistributedDataset, IgnoreLabels
    from MSG_GAN import Losses as lses

    # the transform is given a square (side, side) target size,
    # with side = 2 ** (depth + 1)
    image_size = int(np.power(2, args.depth + 1))
    transform = get_transform((image_size, image_size),
                              flip_horizontal=args.flip_augment)

    # create a data source:
    if args.pytorch_dataset is None:
        data_source = FlatDirectoryImageDataset if not args.folder_distributed \
            else FoldersDistributedDataset

        dataset = data_source(args.images_dir, transform=transform)
    else:
        dataset_name = args.pytorch_dataset.lower()
        if dataset_name == "cifar10":
            # IgnoreLabels wraps the labelled dataset for unconditional
            # training
            dataset = IgnoreLabels(
                CIFAR10(args.dataset_dir,
                        transform=transform,
                        download=True))
        else:
            raise ValueError(
                "Unknown dataset requested: %s" % args.pytorch_dataset)

    data = get_data_loader(dataset, args.batch_size, args.num_workers)
    print("Total number of images in the dataset:", len(dataset))

    # create a gan from these
    msg_gan = MSG_GAN(depth=args.depth,
                      latent_size=args.latent_size,
                      use_eql=args.use_eql,
                      use_ema=args.use_ema,
                      ema_decay=args.ema_decay,
                      device=device)

    def _load(path):
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # resumed on the device this run uses instead of failing to load.
        # NOTE(review): th.load unpickles the file - only load trusted
        # checkpoints.
        return th.load(path, map_location=device)

    if args.generator_file is not None:
        # load the weights into generator
        print("loading generator_weights from:", args.generator_file)
        msg_gan.gen.load_state_dict(_load(args.generator_file))

    if args.shadow_generator_file is not None:
        # load the weights into the shadow (EMA) copy of the generator
        print("loading shadow_generator_weights from:",
              args.shadow_generator_file)
        msg_gan.gen_shadow.load_state_dict(_load(args.shadow_generator_file))

    if args.discriminator_file is not None:
        # load the weights into discriminator
        print("loading discriminator_weights from:", args.discriminator_file)
        msg_gan.dis.load_state_dict(_load(args.discriminator_file))

    # create optimizer for generator: parameters of `gen.style` train with a
    # learning rate scaled by 0.01; the extra 'mult' key records that factor
    # (presumably read elsewhere, e.g. by an LR schedule - confirm).
    gen_params = [{
        'params': msg_gan.gen.style.parameters(),
        'lr': args.g_lr * 0.01,
        'mult': 0.01
    }, {
        'params': msg_gan.gen.layers.parameters(),
        'lr': args.g_lr
    }, {
        'params': msg_gan.gen.rgb_converters.parameters(),
        'lr': args.g_lr
    }]
    gen_optim = th.optim.Adam(gen_params, args.g_lr,
                              [args.adam_beta1, args.adam_beta2])

    dis_optim = th.optim.Adam(msg_gan.dis.parameters(), args.d_lr,
                              [args.adam_beta1, args.adam_beta2])

    if args.generator_optim_file is not None:
        print("loading gen_optim_state from:", args.generator_optim_file)
        gen_optim.load_state_dict(_load(args.generator_optim_file))

    if args.discriminator_optim_file is not None:
        print("loading dis_optim_state from:", args.discriminator_optim_file)
        dis_optim.load_state_dict(_load(args.discriminator_optim_file))

    # loss names are matched case-insensitively
    loss_name = args.loss_function.lower()

    if loss_name == "hinge":
        loss = lses.HingeGAN
    elif loss_name == "relativistic-hinge":
        loss = lses.RelativisticAverageHingeGAN
    elif loss_name == "standard-gan":
        loss = lses.StandardGAN
    elif loss_name == "lsgan":
        loss = lses.LSGAN
    elif loss_name == "lsgan-sigmoid":
        loss = lses.LSGAN_SIGMOID
    elif loss_name == "wgan-gp":
        loss = lses.WGAN_GP
    else:
        raise ValueError(
            "Unknown loss function requested: %s" % args.loss_function)

    # derive a per-run output directory: output/attnmsggan_<data>_<timestamp>
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    if args.pytorch_dataset is not None:
        data_name = 'cifar'
    elif 'celeb' in args.images_dir:
        data_name = 'celeb'
    else:
        # any other image directory is labelled 'flowers'
        data_name = 'flowers'
    output_dir = 'output/%s_%s_%s' % ('attnmsggan', data_name, timestamp)
    args.model_dir = output_dir + '/models'
    args.sample_dir = output_dir + '/images'
    args.log_dir = output_dir + '/logs'

    # train the GAN
    msg_gan.train(data,
                  gen_optim,
                  dis_optim,
                  loss_fn=loss(msg_gan.dis),
                  num_epochs=args.num_epochs,
                  checkpoint_factor=args.checkpoint_factor,
                  data_percentage=args.data_percentage,
                  feedback_factor=args.feedback_factor,
                  num_samples=args.num_samples,
                  sample_dir=args.sample_dir,
                  save_dir=args.model_dir,
                  log_dir=args.log_dir,
                  start=args.start)