コード例 #1
0
def main():
    """Build the GAN selected by ``args.gan_type``, train it, then visualize.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  The parsed options are printed, cuDNN benchmark mode is
    switched on when ``args.benchmark_mode`` is truthy, and the chosen model
    is trained and asked to visualize results for ``args.epoch``.

    Raises:
        Exception: if ``args.gan_type`` is not one of the supported names.
    """
    args = parse_args()
    if args is None:
        exit()

    print(args)

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    def make_generator():
        # Both LSGAN variants share the same generator configuration.
        return model.InfoGANGenerator(input_dim=62,
                                      output_dim=3,
                                      input_size=args.input_size)

    def make_lsgan():
        # Plain LSGAN: InfoGAN-style generator and discriminator.
        generator = make_generator()
        discriminator = model.InfoGANDiscriminator(
            input_dim=3,
            output_dim=1,
            input_size=args.input_size)
        return LSGAN(args, generator, discriminator)

    def make_lsgan_classifier():
        # LSGAN whose discriminator is the classifier variant; it also
        # receives the save directory and the model name.
        generator = make_generator()
        discriminator = model.InfoGANDiscriminatorClassifier(
            input_dim=3,
            output_dim=1,
            input_size=args.input_size,
            save_dir=args.save_dir,
            model_name=args.gan_type)
        return LSGAN(args, generator, discriminator)

    # Factories are wrapped in callables so that only the selected model's
    # class name is resolved and instantiated.
    factories = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': make_lsgan,
        'LSGAN_classifier': make_lsgan_classifier,
        'BEGAN': lambda: BEGAN(args),
    }

    factory = factories.get(args.gan_type)
    if factory is None:
        raise Exception("[!] There is no option for " + args.gan_type)
    gan = factory()

    # Run training.
    gan.train()
    print(" [*] Training finished!")

    # Visualize what the generator learned.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
コード例 #2
0
def main():
    """Build the GAN selected by ``args.gan_type``, train it, then visualize.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  In this variant infoGAN is built with ``SUPERVISED=True``.

    Raises:
        Exception: if ``args.gan_type`` is not one of the supported names.
    """
    args = parse_args()
    if args is None:
        exit()

    # Each constructor call is deferred so that only the selected model's
    # class name is resolved and instantiated.
    constructors = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=True),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': lambda: LSGAN(args),
        'BEGAN': lambda: BEGAN(args),
    }

    construct = constructors.get(args.gan_type)
    if construct is None:
        raise Exception("[!] There is no option for " + args.gan_type)
    gan = construct()

    # Run training.
    gan.train()
    print(" [*] Training finished!")

    # Visualize what the generator learned.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
コード例 #3
0
ファイル: main.py プロジェクト: yaodongyu/pytorch-GANs
def main():
    """Build the GAN selected by ``args.gan_type``, train it, then visualize.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  cuDNN benchmark mode is enabled when ``args.benchmark_mode``
    is truthy.  This variant supports GAN, ACGAN, infoGAN, WGAN, WGAN_GP and
    LSGAN only.

    Raises:
        Exception: if ``args.gan_type`` is not one of the supported names.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # Each constructor call is deferred so that only the selected model's
    # class name is resolved and instantiated.
    constructors = {
        'GAN': lambda: GAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'LSGAN': lambda: LSGAN(args),
    }

    construct = constructors.get(args.gan_type)
    if construct is None:
        raise Exception("[!] There is no option for " + args.gan_type)
    gan = construct()

    # Run training.
    gan.train()
    print(" [*] Training finished!")

    # Visualize what the generator learned.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
コード例 #4
0
ファイル: main.py プロジェクト: AIMarkov/GAN
def main():
    """Build the GAN selected by ``args.gan_type``, train it, then visualize.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  The selected type is echoed to stdout before the model is
    constructed.

    Raises:
        Exception: if ``args.gan_type`` is not one of the supported names.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        # Note translated from the original author's comment: when the
        # network's input dimensions/types vary little, enabling
        # torch.backends.cudnn.benchmark can improve throughput; when the
        # input changes on every iteration, cuDNN searches for the best
        # configuration each time, which instead slows things down.
        torch.backends.cudnn.benchmark = True

    # Each constructor call is deferred so that only the selected model's
    # class name is resolved and instantiated.
    constructors = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': lambda: LSGAN(args),
        'BEGAN': lambda: BEGAN(args),
    }

    construct = constructors.get(args.gan_type)
    if construct is None:
        raise Exception("[!] There is no option for " + args.gan_type)

    # Announce the choice before building the model.
    print("GAN is " + args.gan_type)
    gan = construct()

    # Run training.
    gan.train()
    print(" [*] Training finished!")

    # Visualize what the generator learned.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
コード例 #5
0
def main():
    """Train a GAN or LSGAN on the FashionMNIST training split.

    Options come from ``parse_args()``.  ``args.gan_type`` selects the model
    and ``args.batch_size`` sizes the shuffled data loader.

    Raises:
        Exception: if ``args.gan_type`` is neither ``"GAN"`` nor ``"LSGAN"``.
            Previously this case fell through and crashed with an
            ``UnboundLocalError`` at ``gan.train()``.
    """
    args = parse_args()

    # Constructor calls are deferred so that only the selected class is
    # resolved and instantiated.
    builders = {
        "GAN": lambda loader: GAN(args, loader),
        "LSGAN": lambda loader: LSGAN(args, loader),
    }

    # Validate before touching the dataset so an unsupported type fails
    # fast instead of after the download.
    build = builders.get(args.gan_type)
    if build is None:
        raise Exception(
            "[!] There is no option for {}".format(args.gan_type))

    dataset = datasets.FashionMNIST(
        root='data/fashion-MNIST',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    data_loader = Data.DataLoader(dataset,
                                  shuffle=True,
                                  batch_size=args.batch_size)

    gan = build(data_loader)
    gan.train()
コード例 #6
0
def main():
    """Build the GAN selected by ``args.gan_type`` and train it.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  A timestamped line is printed before and after training.
    CGAN, infoGAN, EBGAN and DRAGAN are disabled in this variant.

    Raises:
        Exception: if ``args.gan_type`` is not one of the supported names.
    """
    args = parse_args()

    # The None guard must come before any attribute access on ``args``;
    # it used to sit after the start-up print, which dereferenced
    # ``args.gan_type`` first and made the guard unreachable.
    if args is None:
        exit()

    print('Training {},started at {}'.format(
        args.gan_type, time.asctime(time.localtime(time.time()))))

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # Select and build the model.
    if args.gan_type == 'GAN':
        gan = GAN(args)
    elif args.gan_type == 'ACGAN':
        gan = ACGAN(args)
    elif args.gan_type == 'WGAN':
        gan = WGAN(args)
    elif args.gan_type == 'WGAN_GP':
        gan = WGAN_GP(args)
    elif args.gan_type == 'LSGAN':
        gan = LSGAN(args)
    elif args.gan_type == 'BEGAN':
        gan = BEGAN(args)
    else:
        raise Exception("[!] There is no option for " + args.gan_type)

    # Run training.
    gan.train()
    print('Training {},finished at {}'.format(
        args.gan_type, time.asctime(time.localtime(time.time()))))
コード例 #7
0
def main():
    """Build, train and visualize the model selected by ``args.gan_type``.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  Everything runs inside a single ``tf.Session`` created with
    soft device placement.  Every supported model is constructed with the
    session plus the same keyword options taken from ``args``.

    Raises:
        Exception: if ``args.gan_type`` is not one of the supported names.
    """
    args = parse_args()
    if args is None:
        exit()

    # Class lookups are deferred so that only the selected model's name is
    # resolved.
    model_lookup = {
        'GAN': lambda: GAN,
        'CGAN': lambda: CGAN,
        'ACGAN': lambda: ACGAN,
        'infoGAN': lambda: infoGAN,
        'EBGAN': lambda: EBGAN,
        'WGAN': lambda: WGAN,
        'WGAN_GP': lambda: WGAN_GP,
        'DRAGAN': lambda: DRAGAN,
        'LSGAN': lambda: LSGAN,
        'BEGAN': lambda: BEGAN,
        'VAE': lambda: VAE,
        'CVAE': lambda: CVAE,
        'VAE_GAN': lambda: VAE_GAN,
    }

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        resolve_model = model_lookup.get(args.gan_type)
        if resolve_model is None:
            raise Exception("[!] There is no option for " + args.gan_type)

        # All models share one constructor signature.
        model_cls = resolve_model()
        gan = model_cls(sess,
                        epoch=args.epoch,
                        batch_size=args.batch_size,
                        z_dim=args.z_dim,
                        dataset_name=args.dataset,
                        checkpoint_dir=args.checkpoint_dir,
                        result_dir=args.result_dir,
                        log_dir=args.log_dir)

        # Build the graph, then print the network's variables.
        gan.build_model()
        show_all_variables()

        # Run training inside the session.
        gan.train()
        print(" [*] Training finished!")

        # Visualize results for the last epoch index (epoch - 1).
        gan.visualize_results(args.epoch - 1)
        print(" [*] Testing finished!")
コード例 #8
0
def main():
    """Build the optional GAN, then either load a classifier or train the GAN.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  ``args.gan_type`` may be ``None``, in which case no GAN is
    built.  When ``args.use_fake_data`` is truthy the result of
    ``gan.load()`` is handed to the classifier as ``fakedata``.  When
    ``args.clf_type == 'clf'`` a ``CLF`` is created and loaded; otherwise
    the GAN is trained.  Classifier training and result visualization are
    disabled in this variant, although both status lines are still printed.

    Raises:
        Exception: if ``args.gan_type`` is an unsupported name, or if a GAN
            is required (fake data requested, or GAN training selected) but
            ``args.gan_type`` is ``None``.  The latter used to surface as
            an ``UnboundLocalError``.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # Each constructor call is deferred so that only the selected model's
    # class name is resolved and instantiated.
    constructors = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': lambda: LSGAN(args),
        'BEGAN': lambda: BEGAN(args),
        'TOGAN': lambda: TOGAN(args),
        'CVAE': lambda: CVAE(args),
    }

    # Always bind ``gan`` so later code can test for it instead of hitting
    # an unbound local when no GAN type was given.
    gan = None
    if args.gan_type is not None:
        construct = constructors.get(args.gan_type)
        if construct is None:
            raise Exception("[!] There is no option for " + args.gan_type)
        gan = construct()

    fakedata = None
    if args.use_fake_data:
        if gan is None:
            raise Exception("[!] use_fake_data requires gan_type to be set")
        fakedata = gan.load()

    if args.clf_type == 'clf':
        clf = CLF(args, fakedata)
        clf.load()
    else:
        if gan is None:
            raise Exception(
                "[!] gan_type must be set when clf_type is not 'clf'")
        gan.train()

    print(" [*] Training finished!")
    print(" [*] Testing finished!")
コード例 #9
0
def main():
    """Build the model selected by ``opts.gan_type``, then train or evaluate.

    Options come from ``parse_args()``; the process exits when that returns
    ``None``.  A saved model is loaded when ``opts.resume`` is truthy or an
    eval mode is given.  With an empty ``opts.eval`` the model is trained;
    otherwise training is skipped and the requested evaluation runs, with
    unrecognised modes falling back to ``gan.manual_inference(opts)``.

    Raises:
        Exception: if ``opts.gan_type`` is not one of the supported names.
    """
    opts = parse_args()
    if opts is None:
        exit()

    # Each constructor call is deferred so that only the selected model's
    # class name is resolved and instantiated.
    constructors = {
        'GAN': lambda: GAN(opts),
        'CGAN': lambda: CGAN(opts),
        'ACGAN': lambda: ACGAN(opts),
        'infoGAN': lambda: infoGAN(opts, SUPERVISED=True),
        'EBGAN': lambda: EBGAN(opts),
        'WGAN': lambda: WGAN(opts),
        'WGAN_GP': lambda: WGAN_GP(opts),
        'DRAGAN': lambda: DRAGAN(opts),
        'LSGAN': lambda: LSGAN(opts),
        'BEGAN': lambda: BEGAN(opts),
        'DRGAN': lambda: DRGAN(opts),
        'AE': lambda: AutoEncoder(opts),
        'GAN3D': lambda: GAN3D(opts),
        'VAEGAN3D': lambda: VAEGAN3D(opts),
        'DRGAN3D': lambda: DRGAN3D(opts),
        'Recog3D': lambda: Recog3D(opts),
        'Recog2D': lambda: Recog2D(opts),
        'VAEDRGAN3D': lambda: VAEDRGAN3D(opts),
        'DRcycleGAN3D': lambda: DRcycleGAN3D(opts),
        'CycleGAN3D': lambda: CycleGAN3D(opts),
        'AE3D': lambda: AutoEncoder3D(opts),
        'DRGAN2D': lambda: DRGAN2D(opts),
        'DRecon3DGAN': lambda: DRecon3DGAN(opts),
        'DRecon2DGAN': lambda: DRecon2DGAN(opts),
        'DReconVAEGAN': lambda: DReconVAEGAN(opts),
    }

    construct = constructors.get(opts.gan_type)
    if construct is None:
        raise Exception("[!] There is no option for " + opts.gan_type)
    gan = construct()

    # Resuming and evaluating both start from a saved model.
    if opts.resume or len(opts.eval) > 0:
        print(" [*] Loading saved model...")
        gan.load()
        print(" [*] Loading finished!")

    # Training path: no eval mode was requested.
    if len(opts.eval) == 0:
        gan.train()
        print(" [*] Training finished!")
        print(" [*] eval mode is not specified!")
        return

    # Evaluation path.
    print(" [*] Training skipped!")
    mode = opts.eval
    if mode == 'generate':
        gan.visualize_results(opts.epoch)
    elif mode == 'interp_z':
        gan.interpolate_z(opts)
    elif mode == 'interp_id':
        gan.interpolate_id(opts)
    elif mode == 'interp_expr':
        gan.interpolate_expr(opts)
    elif mode == 'recon':
        gan.reconstruct()
    elif mode == 'control_expr':
        gan.control_expr()
    else:
        gan.manual_inference(opts)
    print(" [*] Testing finished!")