Example No. 1
def main():
    root_log_dir = "./logs/"
    mkdir_p(root_log_dir)
    batch_size = FLAGS.batch_size
    max_iters = FLAGS.max_iters
    sample_size = 256
    GAN_learn_rate = FLAGS.learn_rate

    OPER_FLAG = FLAGS.OPER_FLAG
    data_In = CelebA(FLAGS.path)
    #print ("the num of dataset", len(data_In.image_list))

    if OPER_FLAG == 0:
        r_fl = 5
        test_list = [sys.argv[1]]
        # 'image_03902', 'image_06751', 'image_06069',
        # 'image_05211', 'image_05757', 'image_05758',
        # 'image_05105', 'image_03877', 'image_04325',
        # 'image_05173', 'image_06667', 'image_03133',
        # 'image_06625', 'image_06757', 'image_04065',
        # 'image_03155'
        t = False
        # Append the captions for this test image to captions_all.txt.
        f = open("captions_all.txt", "a")
        f_c = open(sys.argv[2], "r")
        f.write(sys.argv[1] + '\n')
        for line in f_c:
            f.write(line)
        f.write("----\n")
        f.close()
        f_c.close()
        pggan_checkpoint_dir_write = "./model_flowers_test/"
        sample_path = "./PGGanFlowers/sample_test/"
        mkdir_p(pggan_checkpoint_dir_write)
        mkdir_p(sample_path)
        pggan_checkpoint_dir_read = "./model_flowers_{}/{}/".format(OPER_FLAG, r_fl)

        pggan = PGGAN(batch_size=batch_size, max_iters=max_iters,
                      model_path=pggan_checkpoint_dir_write, read_model_path=pggan_checkpoint_dir_read,
                      data=data_In, sample_size=sample_size,
                      sample_path=sample_path, log_dir=root_log_dir, learn_rate=GAN_learn_rate, PG=r_fl, t=t)

        pggan.build_model_PGGan()
        pggan.test(test_list, int(sys.argv[3]))
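
Example No. 1 pulls its hyperparameters from FLAGS and its test inputs from three positional command-line arguments: the test image identifier (sys.argv[1]), a caption file whose contents are appended to captions_all.txt (sys.argv[2]), and an integer forwarded to pggan.test (sys.argv[3]). The flag definitions are not part of this snippet; below is a minimal sketch, assuming the TF1-style tf.app.flags module commonly used with these scripts, where the flag names are taken from the code above and the default values are illustrative guesses only.

# Hypothetical flag declarations; the names mirror the FLAGS fields used above,
# the defaults are illustrative assumptions, not values from the repository.
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("batch_size", 16, "images per training batch")
flags.DEFINE_integer("max_iters", 20000, "iterations per growth stage")
flags.DEFINE_float("learn_rate", 1e-4, "learning rate for generator and discriminator")
flags.DEFINE_integer("OPER_FLAG", 0, "0 selects the test branch shown above")
flags.DEFINE_string("path", "./data/flowers/", "root directory of the dataset")
FLAGS = flags.FLAGS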
Example No. 2
        for i in range(FLAGS.flag):

            t = False if (i % 2 == 0) else True
            pggan_checkpoint_dir_write = "./PGGanCeleba{}/model_pggan_{}/{}/".format(
                OPER_NAME, OPER_FLAG, fl[i])
            sample_path = "./PGGanCeleba{}/{}/sample_{}_{}".format(
                OPER_NAME, FLAGS.OPER_FLAG, fl[i], t)
            mkdir_p(pggan_checkpoint_dir_write)
            mkdir_p(sample_path)
            pggan_checkpoint_dir_read = "./PGGanCeleba{}/model_pggan_{}/{}/".format(
                OPER_NAME, OPER_FLAG, r_fl[i])

            pggan = PGGAN(batch_size=FLAGS.batch_size,
                          max_iters=FLAGS.max_iters,
                          model_path=pggan_checkpoint_dir_write,
                          read_model_path=pggan_checkpoint_dir_read,
                          data=data_In,
                          sample_size=FLAGS.sample_size,
                          sample_path=sample_path,
                          log_dir=root_log_dir,
                          learn_rate=FLAGS.learn_rate,
                          lam_gp=FLAGS.lam_gp,
                          lam_eps=FLAGS.lam_eps,
                          PG=fl[i],
                          t=t,
                          use_wscale=FLAGS.use_wscale)

            pggan.build_model_PGGan()
            pggan.train()
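
The loop above steps through two parallel lists that are not shown in this snippet: fl[i] is the resolution stage being trained and r_fl[i] is the stage whose checkpoint is restored, while t alternates between a stabilisation pass (even i) and a fade-in/transition pass (odd i). A plausible schedule is sketched below, assuming the pattern suggested by the commented-out r_fl list in Example No. 4; the exact lists used by the repository may differ.

# Hypothetical progressive-growing schedule (illustrative, not from the repo):
# fl[i] -> stage trained on pass i, r_fl[i] -> stage restored before pass i.
fl = [1, 2, 2, 3, 3, 4, 4]
r_fl = [1, 1, 2, 2, 3, 3, 4]
# FLAGS.flag would then be len(fl) == 7: each new stage is first faded in
# (t == True on odd i) and then stabilised (t == False on even i).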
Example No. 3
    return parser.parse_args()


if __name__ == "__main__":
    args = arg_parse()

    args.save_dir = "%s/outs/%s" % (os.getcwd(), args.save_dir)
    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    CUDA = torch.cuda.is_available()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    torch_device = torch.device("cuda") if CUDA else torch.device('cpu')

    data_path = "../../data/celeba-"  # resolution string will be concatenated in ScalableLoader
    loader = ScalableLoader(data_path,
                            shuffle=True,
                            drop_last=True,
                            num_workers=args.cpus,
                            shuffled_cycle=True)

    g = nn.DataParallel(Generator()).to(torch_device)
    d = nn.DataParallel(Discriminator()).to(torch_device)

    tensorboard = TensorboardLogger("%s/tb" % (args.save_dir))

    pggan = PGGAN(args, g, d, loader, torch_device, args.loss, tensorboard)
    pggan.train()
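
This PyTorch example opens with the tail of an arg_parse() helper that the listing cuts off. A minimal sketch of what it might contain is given below, assuming standard argparse; only the fields actually referenced above (save_dir, gpus, cpus, loss) are defined, and the defaults are guesses.

# Hypothetical reconstruction of the truncated arg_parse(); only the arguments
# referenced in Example No. 3 are included, and the defaults are assumptions.
import argparse

def arg_parse():
    parser = argparse.ArgumentParser(description="PGGAN training (PyTorch)")
    parser.add_argument("--save_dir", type=str, default="pggan",
                        help="subdirectory of ./outs/ for checkpoints and logs")
    parser.add_argument("--gpus", type=str, default="0",
                        help="value assigned to CUDA_VISIBLE_DEVICES")
    parser.add_argument("--cpus", type=int, default=4,
                        help="number of DataLoader worker processes")
    parser.add_argument("--loss", type=str, default="wgangp",
                        help="adversarial loss identifier passed to PGGAN")
    return parser.parse_args()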
Example No. 4
        # r_fl = [1, 1, 2, 2, 3, 3, 4]
        fl = [4]
        r_fl = [4]
        
        for i in range(int(FLAGS.flag)):
            t = False if (i % 2 == 0) else True
#             t = False
            pggan_checkpoint_dir_write = "./model_pggan_{}/{}/".format(OPER_FLAG, fl[i])
            sample_path = "./PGGanCifar10/{}/sample_{}_{}".format(FLAGS.OPER_FLAG, fl[i], t)
            mkdir_p(pggan_checkpoint_dir_write)
            mkdir_p(sample_path)
            pggan_checkpoint_dir_read = "./model_pggan_{}/{}/".format(OPER_FLAG, r_fl[i])

            pggan = PGGAN(batch_size=batch_size, max_iters=max_iters,
                            model_path=pggan_checkpoint_dir_write, read_model_path=pggan_checkpoint_dir_read,
                            data=data_In, sample_size=sample_size,
                            sample_path=sample_path, log_dir=root_log_dir, learn_rate=GAN_learn_rate, PG=fl[i],
                            t=t)

            pggan.build_model_PGGan()
            pggan.train()

Example No. 5
            sample_path = "./output/{}/{}/sample_{}_{}".format(
                FLAGS.OPER_NAME, FLAGS.OPER_FLAG, fl[i], t)
            mkdir_p(pggan_checkpoint_dir_write)
            mkdir_p(sample_path)
            pggan_checkpoint_dir_read = "./output/{}/model_pggan_{}/{}/".format(
                FLAGS.OPER_NAME, FLAGS.OPER_FLAG, r_fl[i])

            pggan = PGGAN(oper_name=FLAGS.OPER_NAME,
                          batch_size=FLAGS.batch_size,
                          max_iters=FLAGS.max_iters,
                          model_path=pggan_checkpoint_dir_write,
                          read_model_path=pggan_checkpoint_dir_read,
                          data=data_In,
                          sample_size=FLAGS.sample_size,
                          sample_path=sample_path,
                          log_dir=root_log_dir,
                          learn_rate=FLAGS.learn_rate,
                          lam_gp=FLAGS.lam_gp,
                          lam_eps=FLAGS.lam_eps,
                          PG=fl[i],
                          trans=t,
                          use_wscale=FLAGS.use_wscale,
                          is_celeba=FLAGS.celeba,
                          step_by_save_sample=FLAGS.step_by_save_sample,
                          step_by_save_weights=FLAGS.step_by_save_weights)

            pggan.build_model_PGGan()

            start_time = datetime.datetime.now()
            pggan.train()
            end_time = datetime.datetime.now()
            pggan.f_logger.write('start_time:{}'.format(str(start_time)) +
Example No. 6
        sample_path = "./output/{}/{}/sample_{}_{}".format(
            FLAGS.OPER_NAME, 0, 7, t)
        sample_path22 = "./output/{}/{}/Indiaviual_sample_{}_{}".format(
            FLAGS.OPER_NAME, FLAGS.OPER_FLAG, 7, t)
        mkdir_p(pggan_checkpoint_dir_write)
        mkdir_p(sample_path)
        mkdir_p(sample_path22)
        pggan_checkpoint_dir_read = "./output/{}/model_pggan_{}/{}/".format(
            FLAGS.OPER_NAME, FLAGS.OPER_FLAG, 7)
        print(pggan_checkpoint_dir_read)
        #pdb.set_trace()
        pggan = PGGAN(batch_size=FLAGS.batch_size,
                      max_iters=1,
                      model_path=pggan_checkpoint_dir_write,
                      read_model_path=pggan_checkpoint_dir_read,
                      data=data_In,
                      sample_size=FLAGS.sample_size,
                      sample_path=sample_path,
                      sample_path2=sample_path22,
                      log_dir=root_log_dir,
                      learn_rate=FLAGS.learn_rate,
                      lam_gp=FLAGS.lam_gp,
                      lam_eps=FLAGS.lam_eps,
                      PG=7,
                      t=t,
                      use_wscale=FLAGS.use_wscale,
                      is_celeba=FLAGS.celeba)

        pggan.build_model_PGGan()
        pggan.test()
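
Every TensorFlow example on this page creates its checkpoint and sample directories through mkdir_p, which is never shown. A minimal sketch follows, assuming the helper simply reproduces mkdir -p semantics (create the directory and any missing parents, and do nothing if it already exists).

# Hypothetical mkdir_p helper with "mkdir -p" semantics.
import errno
import os

def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc:
        # Ignore the error only if the directory already exists.
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise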