Example #1
0
def main(_):
  """Run the finetuned WGAN_GP generator over the first 50 test-list entries.

  Builds the network for single-image inference, restores the weights stored
  at ``cfg.checkpoint_ft`` and, for each entry, reads the image with
  ``read_img``, runs ``net.gen_p`` on it and hands the pair to ``save_img``.

  Args:
    _: Unused positional argument (presumably argv from ``tf.app.run`` —
      confirm against the entry point).
  """
  if not os.path.exists('./test'):
    os.makedirs('./test')

  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  os.environ["CUDA_VISIBLE_DEVICES"] = cfg.device_id

  # Inference feeds exactly one image per run (placeholders below are [1, ...]).
  cfg.batch_size = 1
  net = WGAN_GP()

  profile = tf.placeholder(tf.float32, [1, 224, 224, 3], name='profile')
  # Distinct op name: reusing 'profile' makes TF silently rename this
  # placeholder to 'profile_1', which breaks any lookup by name.
  front = tf.placeholder(tf.float32, [1, 224, 224, 3], name='front')
  net.build_up(profile, front)

  # Test
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config, graph=net.graph) as sess:
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=0)
    saver.restore(sess, cfg.checkpoint_ft)
    # Report success only once the restore has actually completed.
    print('Load Finetuned Model Successfully!')

    # dtype=str yields Python str entries (the legacy 'string' alias is a
    # bytes-oriented name not accepted by every NumPy version); ndmin=1 keeps
    # a single-entry list iterable instead of producing an unsliceable 0-d array.
    test_list = np.loadtxt(cfg.test_list, dtype=str, delimiter=',', ndmin=1)
    for img in test_list[:50]:
      print(img)
      img_np = read_img(cfg.test_path, img)
      # Only the profile input is fed to produce the generated image.
      img_gen = sess.run(net.gen_p, {profile: img_np, net.is_train: False})
      save_img('test', img, img_np, img_gen)
Example #2
0
def main():
    """Build the GAN named by ``args.gan_type``, optionally restore it, train, visualize.

    Supported types are EBGAN, WGAN_GP and BEGAN; any other value raises.
    Restoring happens when ``load_flag == True``. ``load_flag`` is not defined
    here — presumably a module-level global; confirm.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    chosen = args.gan_type
    if chosen == 'EBGAN':
        model_cls = EBGAN
    elif chosen == 'WGAN_GP':
        model_cls = WGAN_GP
    elif chosen == 'BEGAN':
        model_cls = BEGAN
    else:
        raise Exception("[!] There is no option for " + args.gan_type)

    gan = model_cls(args)
    # Equality against True (not plain truthiness) is kept on purpose.
    if load_flag == True:
        model_cls.load(gan)
        # Only the BEGAN path announces a successful restore.
        if chosen == 'BEGAN':
            print('load successful')

    gan.train()
    print(" [*] Training finished!")

    # Render samples from the trained generator.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
    print(" [*] Testing finished!")
Example #3
0
def main():
    """Train the GAN chosen by ``args.gan_type`` and render its samples.

    Accepted types: GAN, ACGAN, infoGAN, WGAN, WGAN_GP, LSGAN. Any other value
    raises ``Exception``.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # Factories defer both class lookup and construction until one is chosen.
    factories = {
        'GAN': lambda: GAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'LSGAN': lambda: LSGAN(args),
    }
    make_gan = factories.get(args.gan_type)
    if make_gan is None:
        raise Exception("[!] There is no option for " + args.gan_type)
    gan = make_gan()

    gan.train()
    print(" [*] Training finished!")

    # Render samples from the trained generator.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
Example #4
0
def main():
    """Parse arguments, build the requested GAN, train it and render samples.

    The two LSGAN variants are assembled from explicit ``model.InfoGAN*``
    generator/discriminator modules; every other type is built from ``args``
    alone. An unrecognised ``args.gan_type`` raises ``Exception``.
    """
    args = parse_args()
    if args is None:
        exit()

    print(args)

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    def build_generator():
        # Generator shared by both LSGAN variants.
        return model.InfoGANGenerator(input_dim=62,
                                      output_dim=3,
                                      input_size=args.input_size)

    def build_lsgan():
        generator = build_generator()
        discriminator = model.InfoGANDiscriminator(input_dim=3,
                                                   output_dim=1,
                                                   input_size=args.input_size)
        return LSGAN(args, generator, discriminator)

    def build_lsgan_classifier():
        generator = build_generator()
        discriminator = model.InfoGANDiscriminatorClassifier(
            input_dim=3,
            output_dim=1,
            input_size=args.input_size,
            save_dir=args.save_dir,
            model_name=args.gan_type)
        return LSGAN(args, generator, discriminator)

    # Lazy builders: nothing is looked up or constructed until selected.
    builders = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': build_lsgan,
        'LSGAN_classifier': build_lsgan_classifier,
        'BEGAN': lambda: BEGAN(args),
    }
    builder = builders.get(args.gan_type)
    if builder is None:
        raise Exception("[!] There is no option for " + args.gan_type)
    gan = builder()

    gan.train()
    print(" [*] Training finished!")

    # Render samples from the trained generator.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
Example #5
0
File: main.py Project: AIMarkov/GAN
def main():
    """Train the GAN named by ``args.gan_type`` after announcing it, then render samples.

    Unknown types raise ``Exception`` without printing the announcement.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        # When input sizes/types barely change between iterations, enabling
        # cudnn.benchmark speeds things up. If inputs change every iteration,
        # cudnn re-searches for the best algorithm each time, which slows
        # training down instead.
        torch.backends.cudnn.benchmark = True

    constructors = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': lambda: LSGAN(args),
        'BEGAN': lambda: BEGAN(args),
    }
    construct = constructors.get(args.gan_type)
    if construct is None:
        raise Exception("[!] There is no option for " + args.gan_type)
    print("GAN is " + args.gan_type)
    gan = construct()

    gan.train()
    print(" [*] Training finished!")

    # Render samples from the trained generator.
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
def main(_):
    """Prepare directories, echo the configuration, then train and/or test WGAN_GP.

    The session may use at most half of each visible GPU's memory
    (``per_process_gpu_memory_fraction=0.5``). Training and testing are toggled
    independently by ``FLAGS.is_training`` and ``FLAGS.is_testing``.
    """
    check_dir()
    print_config()
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))
    with tf.Session(config=session_config) as sess:
        model = WGAN_GP(config=FLAGS, sess=sess)
        model.build_model()
        if FLAGS.is_training:
            model.train_model()
        if FLAGS.is_testing:
            model.test_model()
Example #7
0
def main(args):
    """Train WGAN_GP on an ``ImageFolder`` dataset.

    Images under ``args.dataset_path`` are resized to a square of side
    ``args.img_size``, converted with ``ToTensor`` and normalised per channel
    with mean 0.5 / std 0.5. Training runs only when ``args.is_train`` is the
    object ``True`` itself (identity check).
    """
    # WGAN-GP is trained on one sub-category only, moved into dataset_path by hand.
    side = args.img_size
    preprocess = Compose([
        Resize((side, side)),
        ToTensor(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    images = ImageFolder(args.dataset_path, transform=preprocess)
    loader = DataLoader(images, batch_size=args.batch_size, shuffle=True)
    print(len(loader.dataset))

    # Args the model relies on: channels, batch size, iterations, cuda, train.
    model = WGAN_GP(args)

    if args.is_train is True:
        model.train(loader)
Example #8
0
def main():
    """Train WGAN_GP on ``eyes.tfrecords`` and save samples for the last epoch.

    The epoch count is bound once so that the sample label passed to
    ``visualize_results`` (``epoch - 1``) cannot drift from the training length.
    """
    epoch = 10000
    with tf.Session() as sess:
        gan = WGAN_GP(sess,
                      epoch=epoch,
                      batch_size=16,
                      dataset_name='eyes.tfrecords',
                      checkpoint_dir='checkpoint',
                      result_dir='results',
                      log_dir='logs')
        # build graph
        gan.build_model()
        show_all_variables()
        gan.train()

        print(" [*] Training finished!")
        # Label samples with the last trained epoch; the former literal
        # `20 - 1` was stale relative to epoch=10000.
        gan.visualize_results(epoch - 1)
        print(" [*] Testing finished!")
Example #9
0
def main():
    """Train and visualize one of the GAN variants enabled in this entry point.

    Enabled types: LSGAN, LSGAN_gene, WGAN_GP.

    Raises:
        Exception: if ``args.gan_type`` is not one of the enabled types.
    """
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    if args.gan_type == 'LSGAN':
        gan = LSGAN(args)
    elif args.gan_type == 'LSGAN_gene':
        gan = LSGAN_gene(args)
    elif args.gan_type == 'WGAN_GP':
        gan = WGAN_GP(args)
    else:
        # Without this branch an unknown type left `gan` unbound and crashed
        # later with UnboundLocalError at gan.train().
        raise Exception("[!] There is no option for " + args.gan_type)

    # launch the graph in a session
    gan.train()
    print(" [*] Training finished!")

    # visualize learned generator
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")
Example #10
0
def main():
    """Train the GAN selected by ``--gan_type``, logging start and finish times.

    Raises:
        Exception: if ``args.gan_type`` is not an enabled type.
    """
    args = parse_args()
    # Check before first use: the start banner dereferences args.gan_type and
    # would raise AttributeError if parse_args() returned None.
    if args is None:
        exit()

    # time.asctime() with no argument formats the current local time.
    print('Training {},started at {}'.format(args.gan_type, time.asctime()))

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # declare instance for GAN
    # (CGAN, infoGAN, EBGAN and DRAGAN are currently disabled in this entry point.)
    if args.gan_type == 'GAN':
        gan = GAN(args)
    elif args.gan_type == 'ACGAN':
        gan = ACGAN(args)
    elif args.gan_type == 'WGAN':
        gan = WGAN(args)
    elif args.gan_type == 'WGAN_GP':
        gan = WGAN_GP(args)
    elif args.gan_type == 'LSGAN':
        gan = LSGAN(args)
    elif args.gan_type == 'BEGAN':
        gan = BEGAN(args)
    else:
        raise Exception("[!] There is no option for " + args.gan_type)

    gan.train()
    print('Training {},finished at {}'.format(args.gan_type, time.asctime()))
Example #11
0
def main():
    """Train WGAN_GP on ``eyes.tfrecords`` and save samples for the last epoch.

    The epoch count is bound once so that the sample label passed to
    ``visualize_results`` (``epoch - 1``) cannot drift from the training length.
    """
    epoch = 10000
    with tf.Session() as sess:
        gan = WGAN_GP(sess,
                      epoch=epoch,
                      batch_size=16,
                      dataset_name='eyes.tfrecords',
                      checkpoint_dir='checkpoint',
                      result_dir='results',
                      log_dir='logs')
        # build graph
        gan.build_model()
        show_all_variables()
        gan.train()

        print(" [*] Training finished!")
        # Label samples with the last trained epoch; the former literal
        # `20-1` was stale relative to epoch=10000.
        gan.visualize_results(epoch - 1)
        print(" [*] Testing finished!")
Example #12
0
def main(_):
  """Train (or finetune) WGAN_GP on the queue-fed profile/front pipeline.

  Creates the results, checkpoint and summary directories, builds the network
  on ``data_feed.get_train()`` tensors and runs ``cfg.epoch`` epochs of
  ``cfg.dataset_size / cfg.batch_size`` steps. Each step performs ``critic``
  discriminator updates followed by one generator update. Every
  ``cfg.save_freq`` steps (step 0 included) it writes a summary, saves a
  checkpoint ``ck-<epoch>`` and evaluates 50 test batches, printing the mean
  loss per batch.

  Args:
    _: Unused (presumably argv from ``tf.app.run`` — confirm).
  """
  # Environment Setting
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  os.environ["CUDA_VISIBLE_DEVICES"] = cfg.device_id

  for path in (cfg.results, cfg.checkpoint, cfg.summary_dir):
    if not os.path.exists(path):
      os.makedirs(path)

  # Construct Networks
  net = WGAN_GP()
  data_feed = loadData(batch_size=cfg.batch_size, train_shuffle=True)
  profile, front = data_feed.get_train()
  net.build_up(profile, front)

  # Train
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  with tf.Session(config=config, graph=net.graph) as sess:
    sess.run(tf.global_variables_initializer())

    # Input-pipeline threads; stopped in `finally` so an exception during
    # restore/training does not leave them running.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    writer = None
    try:
      saver = tf.train.Saver(max_to_keep=0)  # keep every checkpoint
      if cfg.is_finetune:
        saver.restore(sess, cfg.checkpoint_ft)
        print('Load Finetuned Model Successfully!')

      num_batch = int(cfg.dataset_size / cfg.batch_size)
      writer = tf.summary.FileWriter(cfg.summary_dir, sess.graph)
      num_test_batches = 50

      # Train by minibatch and critic
      for epoch in range(cfg.epoch):
        for step in range(num_batch):
          # Discriminator: 25 critic updates for the first 25 steps of a fresh
          # (non-finetune) run, cfg.critic afterwards.
          if step < 25 and epoch == 0 and not cfg.is_finetune:
            critic = 25
          else:
            critic = cfg.critic
          for _ in range(critic):
            sess.run(net.train_dis, {net.is_train: True})

          # Generator
          _, fl, gl, dl, gen, g1, g2, g4, summary = sess.run(
              [net.train_gen, net.feature_loss, net.g_loss, net.d_loss,
               net.gen_p, net.grad1, net.grad2, net.grad4, net.train_summary],
              {net.is_train: True})

          print('%d-%d, Fea Loss:%.2f, D Loss:%4.1f, G Loss:%4.1f, g1/2/4:%.5f/%.5f/%.5f ' %
                (epoch, step, fl, dl, gl, g1*cfg.lambda_fea, g2, g4))

          # Save Model and Summary and Test
          if step % cfg.save_freq == 0:
            writer.add_summary(summary, epoch*num_batch + step)
            print("Saving Model....")
            saver.save(sess, os.path.join(cfg.checkpoint, 'ck-%02d' % (epoch)))

            # test: feed test batches in place of the pipeline tensors
            fl, dl, gl = 0., 0., 0.
            for _ in range(num_test_batches):
              te_profile, te_front = data_feed.get_test_batch(cfg.batch_size)
              dl_, gl_, fl_, images = sess.run(
                  [net.d_loss, net.g_loss, net.feature_loss, net.gen_p],
                  {profile: te_profile, front: te_front, net.is_train: False})
              data_feed.save_images(images, epoch)
              dl += dl_
              gl += gl_
              fl += fl_
            # Report per-batch means so the numbers are comparable with the
            # per-batch training losses printed above.
            print('Testing: Fea Loss:%.1f, D Loss:%.1f, G Loss:%.1f' %
                  (fl / num_test_batches, dl / num_test_batches,
                   gl / num_test_batches))
    finally:
      # Flush pending summaries and shut the pipeline threads down.
      if writer is not None:
        writer.close()
      coord.request_stop()
      coord.join(threads)
Example #13
0
def main():
    """Train a TensorFlow generative model chosen by ``--gan_type`` and save samples.

    Every supported model class takes the same constructor arguments, so the
    class is selected first and constructed once. The model is then built,
    trained and asked to visualize results for the last epoch
    (``args.epoch - 1``).

    Raises:
        Exception: if ``args.gan_type`` is not a supported model name.
    """
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        # declare instance for GAN: pick the class, then construct it once
        name = args.gan_type
        if name == 'GAN':
            model_cls = GAN
        elif name == 'CGAN':
            model_cls = CGAN
        elif name == 'ACGAN':
            model_cls = ACGAN
        elif name == 'infoGAN':
            model_cls = infoGAN
        elif name == 'EBGAN':
            model_cls = EBGAN
        elif name == 'WGAN':
            model_cls = WGAN
        elif name == 'WGAN_GP':
            model_cls = WGAN_GP
        elif name == 'DRAGAN':
            model_cls = DRAGAN
        elif name == 'LSGAN':
            model_cls = LSGAN
        elif name == 'BEGAN':
            model_cls = BEGAN
        elif name == 'VAE':
            model_cls = VAE
        elif name == 'CVAE':
            model_cls = CVAE
        elif name == 'VAE_GAN':
            model_cls = VAE_GAN
        else:
            raise Exception("[!] There is no option for " + args.gan_type)

        gan = model_cls(sess,
                        epoch=args.epoch,
                        batch_size=args.batch_size,
                        z_dim=args.z_dim,
                        dataset_name=args.dataset,
                        checkpoint_dir=args.checkpoint_dir,
                        result_dir=args.result_dir,
                        log_dir=args.log_dir)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        # launch the graph in a session
        gan.train()
        print(" [*] Training finished!")

        # visualize learned generator
        gan.visualize_results(args.epoch - 1)
        print(" [*] Testing finished!")
Example #14
0
def main():
    """Build the model named by ``opts.gan_type``, then train it or run an eval mode.

    The saved model is loaded when ``opts.resume`` is set or an eval mode is
    given. Training runs only when ``opts.eval`` is empty; otherwise the
    matching eval action runs, falling back to ``manual_inference``.
    """
    opts = parse_args()
    if opts is None:
        exit()

    # Lazy constructors keyed by --gan_type value.
    constructors = {
        'GAN': lambda: GAN(opts),
        'CGAN': lambda: CGAN(opts),
        'ACGAN': lambda: ACGAN(opts),
        'infoGAN': lambda: infoGAN(opts, SUPERVISED=True),
        'EBGAN': lambda: EBGAN(opts),
        'WGAN': lambda: WGAN(opts),
        'WGAN_GP': lambda: WGAN_GP(opts),
        'DRAGAN': lambda: DRAGAN(opts),
        'LSGAN': lambda: LSGAN(opts),
        'BEGAN': lambda: BEGAN(opts),
        'DRGAN': lambda: DRGAN(opts),
        'AE': lambda: AutoEncoder(opts),
        'GAN3D': lambda: GAN3D(opts),
        'VAEGAN3D': lambda: VAEGAN3D(opts),
        'DRGAN3D': lambda: DRGAN3D(opts),
        'Recog3D': lambda: Recog3D(opts),
        'Recog2D': lambda: Recog2D(opts),
        'VAEDRGAN3D': lambda: VAEDRGAN3D(opts),
        'DRcycleGAN3D': lambda: DRcycleGAN3D(opts),
        'CycleGAN3D': lambda: CycleGAN3D(opts),
        'AE3D': lambda: AutoEncoder3D(opts),
        'DRGAN2D': lambda: DRGAN2D(opts),
        'DRecon3DGAN': lambda: DRecon3DGAN(opts),
        'DRecon2DGAN': lambda: DRecon2DGAN(opts),
        'DReconVAEGAN': lambda: DReconVAEGAN(opts),
    }
    construct = constructors.get(opts.gan_type)
    if construct is None:
        raise Exception("[!] There is no option for " + opts.gan_type)
    gan = construct()

    if opts.resume or len(opts.eval) > 0:
        print(" [*] Loading saved model...")
        gan.load()
        print(" [*] Loading finished!")

    # Training only happens when no eval mode was requested.
    if len(opts.eval) != 0:
        print(" [*] Training skipped!")
    else:
        gan.train()
        print(" [*] Training finished!")

    if len(opts.eval) == 0:
        print(" [*] eval mode is not specified!")
        return

    # Ordered (mode, action) pairs; compared with == like an if/elif chain.
    eval_actions = (
        ('generate', lambda: gan.visualize_results(opts.epoch)),
        ('interp_z', lambda: gan.interpolate_z(opts)),
        ('interp_id', lambda: gan.interpolate_id(opts)),
        ('interp_expr', lambda: gan.interpolate_expr(opts)),
        ('recon', lambda: gan.reconstruct()),
        ('control_expr', lambda: gan.control_expr()),
    )
    for mode, action in eval_actions:
        if opts.eval == mode:
            action()
            break
    else:
        gan.manual_inference(opts)
    print(" [*] Testing finished!")
Example #15
0
def main():
    """Either load a classifier (optionally fed GAN samples) or train a GAN.

    ``args.gan_type`` may be None for classifier-only runs. A step that needs
    a GAN (``args.use_fake_data``, or the GAN-training path) now raises a
    clear ``Exception`` when no GAN was requested. Previously such a step
    crashed with ``UnboundLocalError`` on the unbound ``gan``.

    Raises:
        Exception: for an unknown ``gan_type``, or a GAN-dependent step without one.
    """
    args = parse_args()
    if args is None:
        exit()

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # declare instance for GAN (lazy: only the selected class is touched)
    factories = {
        'GAN': lambda: GAN(args),
        'CGAN': lambda: CGAN(args),
        'ACGAN': lambda: ACGAN(args),
        'infoGAN': lambda: infoGAN(args, SUPERVISED=False),
        'EBGAN': lambda: EBGAN(args),
        'WGAN': lambda: WGAN(args),
        'WGAN_GP': lambda: WGAN_GP(args),
        'DRAGAN': lambda: DRAGAN(args),
        'LSGAN': lambda: LSGAN(args),
        'BEGAN': lambda: BEGAN(args),
        'TOGAN': lambda: TOGAN(args),
        'CVAE': lambda: CVAE(args),
    }
    if args.gan_type is None:
        gan = None  # classifier-only run; no GAN requested
    elif args.gan_type in factories:
        gan = factories[args.gan_type]()
    else:
        raise Exception("[!] There is no option for " + args.gan_type)

    if args.use_fake_data:
        if gan is None:
            raise Exception("[!] use_fake_data requires a gan_type")
        fakedata = gan.load()
    else:
        fakedata = None

    if args.clf_type == 'clf':
        clf = CLF(args, fakedata)
        # Classifier training is currently disabled; only the saved model is loaded.
        clf.load()
    else:
        if gan is None:
            raise Exception("[!] training a GAN requires a gan_type")
        gan.train()

    print(" [*] Training finished!")

    # Visualization of the learned generator is currently disabled.
    print(" [*] Testing finished!")
Example #16
0
def main(_):
    """Train WGAN_GP with periodic testing and checkpointing.

    Each step runs ``critic`` discriminator updates and one generator update,
    writing a summary every step. Every ``cfg.test_sum_freq`` steps the current
    training output is saved and ``test_num`` test batches are evaluated, with
    mean losses printed. Every ``cfg.save_freq`` steps after step 0 a checkpoint
    ``cfg.logdir + '-<epoch>'`` is written.

    Args:
        _: Unused (presumably argv from ``tf.app.run`` — confirm).
    """
    # Environment Setting
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # NOTE(review): GPU is hard-coded to "3"; sibling entry points use cfg.device_id.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    if not os.path.exists(cfg.results):
        # makedirs also creates missing parents (os.mkdir would raise).
        os.makedirs(cfg.results)

    # Construct Networks
    # Change this line if 'LSGAN' or 'WGAN'
    net = WGAN_GP()

    # Train and Test
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config, graph=net.graph) as sess:
        sess.run(tf.global_variables_initializer())

        # Input-pipeline threads; stopped in `finally` so an exception during
        # restore/training does not leave them running.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        writer = None
        try:
            saver = tf.train.Saver(max_to_keep=0)  # keep every checkpoint
            if cfg.is_finetune:
                saver.restore(sess, cfg.model_path)
                print('Load Finetuned Model Successfully!')

            num_batch = int(cfg.dataset_size / cfg.batch_size)
            writer = tf.summary.FileWriter(cfg.summary_dir, sess.graph)

            # Train by minibatch and critic
            for epoch in range(cfg.epoch):
                for step in range(num_batch):
                    # Discriminator: 25 critic updates for the first 25 steps
                    # of a fresh (non-finetune) run, cfg.critic afterwards.
                    if step < 25 and epoch == 0 and not cfg.is_finetune:
                        critic = 25
                    else:
                        critic = cfg.critic
                    for _ in range(critic):
                        # add 'net.clip_D' into ops if 'LSGAN' or 'WGAN'
                        sess.run(net.train_dis, {net.is_train: True})

                    # Generative Part
                    _, fl, gl, dl, gen, summary = sess.run([
                        net.train_gen, net.feature_loss, net.g_loss, net.d_loss,
                        net.gen_p, net.train_summary
                    ], {net.is_train: True})
                    writer.add_summary(summary, epoch * num_batch + step)
                    print('%d-%d, Fea Loss:%.2f, D Loss:%4.1f, G Loss:%4.1f'
                          % (epoch, step, fl, dl, gl))

                    # Test Part
                    if step % cfg.test_sum_freq == 0:
                        net.data_feed.save_train(gen)
                        # NOTE(review): test_num is not defined in this function;
                        # presumably a module-level constant — confirm.
                        fl, dl, gl = 0., 0., 0.
                        for _ in range(test_num):
                            te_profile, te_front = net.data_feed.get_test_batch(
                                cfg.batch_size)
                            dl_, gl_, fl_, images = sess.run(
                                [net.d_loss, net.g_loss,
                                 net.feature_loss, net.gen_p],
                                {net.profile: te_profile, net.front: te_front,
                                 net.is_train: False})
                            net.data_feed.save_images(images, epoch)
                            dl += dl_
                            gl += gl_
                            fl += fl_
                        print('Testing: Fea Loss:%.1f, D Loss:%.1f, G Loss:%.1f' %
                              (fl / test_num, dl / test_num, gl / test_num))

                    # Save Model
                    if step != 0 and step % cfg.save_freq == 0:
                        print("Saving Model....")
                        saver.save(sess, cfg.logdir + '-%02d' % (epoch))
        finally:
            # Flush pending summaries and shut the pipeline threads down.
            if writer is not None:
                writer.close()
            coord.request_stop()
            coord.join(threads)
Example #17
0
from utils import check_folder
import tensorflow as tf
import argparse

desc = "dimension"
parser = argparse.ArgumentParser(description=desc)
# Required: without it args.dim is None and z_dim=None would be passed to WGAN_GP.
parser.add_argument('--dim', type=int, required=True, help='input dimension')
args = parser.parse_args()
dim = args.dim

# NOTE(review): WGAN_GP and show_all_variables are not imported by the lines
# above (only check_folder is); confirm they are brought into scope.
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

    gan = WGAN_GP(sess,
                  epoch=20,
                  batch_size=128,
                  z_dim=dim,
                  dataset_name='fashion-mnist',
                  checkpoint_dir='checkpoints',
                  result_dir='results',
                  log_dir='logs')
    if gan is None:
        # This parser defines no --gan_type; the former message read
        # args.gan_type and would itself have raised AttributeError.
        raise Exception("[!] Failed to construct WGAN_GP")

    # build graph
    gan.build_model()

    # show network architecture
    show_all_variables()

    # launch the graph in a session
    gan.train()
    print(" [*] Training finished!")