Esempio n. 1
0
def main(_):
    """Entry point: optionally train the Inception network from ``FLAGS.cfg``.

    Prints the parsed flags, loads the YAML configuration, makes sure the
    checkpoint and log directories exist, loads the test split of the text
    dataset and, when ``cfg.TRAIN.FLAG`` is truthy, runs ``InceptionTrainer``
    inside a TensorFlow session. Nothing is run when the flag is falsy.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    pp.pprint(flags.FLAGS.__flags)

    cfg = config_from_yaml(FLAGS.cfg)

    # Create any output directory that does not exist yet.
    for dir_key in ('CHECKPOINT_DIR', 'LOGS_DIR'):
        if not os.path.exists(getattr(cfg, dir_key)):
            os.makedirs(getattr(cfg, dir_key))

    # Ask TensorFlow to grow its GPU memory allocation as needed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    data_root = cfg.DATASET_DIR
    # NOTE(review): 299 is presumably the image size expected by Inception;
    # confirm against TextDataset.
    dataset = TextDataset(data_root, 299)

    # Inception is trained on the test split, whose classes are completely
    # different from those of the train split (used for GAN training). This is
    # required for a correct evaluation of the Inception/FID score.
    dataset.test = dataset.get_data('%s/test' % data_root)

    with tf.Session(config=session_config) as sess:
        # There is no non-training mode: a falsy TRAIN.FLAG does nothing.
        if cfg.TRAIN.FLAG:
            trainer = InceptionTrainer(sess=sess, dataset=dataset, cfg=cfg)
            trainer.train()
Esempio n. 2
0
def main(_):
    """Entry point: evaluate, train or visualize the stage II conditional GAN.

    Loads the stage I and stage II YAML configurations, makes sure the
    stage II checkpoint, sample and log directories exist, loads the test and
    train splits of the text dataset and then runs exactly one mode:

    * ``cfg.EVAL.FLAG`` truthy: Inception evaluation via ``StageIIEval``.
    * otherwise ``cfg.TRAIN.FLAG`` truthy: training via
      ``ConditionalGanTrainer``.
    * otherwise: visualization via ``StageIIVisualizer``.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    pp.pprint(flags.FLAGS.__flags)

    cfg_stage_i = config_from_yaml(FLAGS.cfg_stage_I)
    cfg = config_from_yaml(FLAGS.cfg_stage_II)

    # Create any output directory that does not exist yet.
    for dir_key in ('CHECKPOINT_DIR', 'SAMPLE_DIR', 'LOGS_DIR'):
        if not os.path.exists(getattr(cfg, dir_key)):
            os.makedirs(getattr(cfg, dir_key))

    # Ask TensorFlow to grow its GPU memory allocation as needed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    data_root = cfg.DATASET_DIR
    dataset = TextDataset(data_root, 256)
    dataset.test = dataset.get_data('%s/test' % data_root)
    dataset.train = dataset.get_data('%s/train' % data_root)

    with tf.Session(config=session_config) as sess:
        # Evaluation takes precedence over training; visualization is the
        # fallback when neither flag is set.
        if cfg.EVAL.FLAG:
            mode = 'eval'
        elif cfg.TRAIN.FLAG:
            mode = 'train'
        else:
            mode = 'visualize'

        # Every mode wraps a stage I network created with build_model=False.
        stage_i = ConditionalGanStageI(cfg_stage_i, build_model=False)

        if mode == 'train':
            stage_ii = ConditionalGan(stage_i, cfg)
            show_all_variables()
            trainer = ConditionalGanTrainer(
                sess=sess,
                model=stage_ii,
                dataset=dataset,
                cfg=cfg,
                cfg_stage_i=cfg_stage_i,
            )
            trainer.train()
        else:
            stage_ii = ConditionalGan(stage_i, cfg, build_model=False)
            runner_kwargs = dict(
                sess=sess,
                model=stage_ii,
                dataset=dataset,
                cfg=cfg,
            )
            if mode == 'eval':
                StageIIEval(**runner_kwargs).evaluate_inception()
            else:
                StageIIVisualizer(**runner_kwargs).visualize()
Esempio n. 3
0
def main(_):
    """Entry point: evaluate, train or visualize the GAN-CLS model.

    Loads the YAML configuration named by ``FLAGS.cfg``, makes sure the
    checkpoint, sample and log directories exist, loads the test and train
    splits of the text dataset and then runs exactly one mode:

    * ``cfg.EVAL.FLAG`` truthy: Inception evaluation via ``GanClsEval``.
    * otherwise ``cfg.TRAIN.FLAG`` truthy: training via ``GanClsTrainer``.
    * otherwise: visualization via ``GanClsVisualizer``.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    pp.pprint(flags.FLAGS.__flags)

    cfg = config_from_yaml(FLAGS.cfg)

    # Create any output directory that does not exist yet.
    for dir_key in ('CHECKPOINT_DIR', 'SAMPLE_DIR', 'LOGS_DIR'):
        if not os.path.exists(getattr(cfg, dir_key)):
            os.makedirs(getattr(cfg, dir_key))

    # Ask TensorFlow to grow its GPU memory allocation as needed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    data_root = cfg.DATASET_DIR
    # A second argument of 512 was also tried here at some point.
    dataset = TextDataset(data_root, 64)

    # NOTE(review): the test split is stored on `_test` while the train split
    # goes to `train`; confirm this asymmetry against TextDataset.
    dataset._test = dataset.get_data('%s/test' % data_root)
    dataset.train = dataset.get_data('%s/train' % data_root)

    with tf.Session(config=session_config) as sess:
        if cfg.EVAL.FLAG:
            model = GanCls(cfg, build_model=False)
            evaluator = GanClsEval(
                sess=sess,
                model=model,
                dataset=dataset,
                cfg=cfg)
            evaluator.evaluate_inception()
        elif cfg.TRAIN.FLAG:
            model = GanCls(cfg)
            show_all_variables()
            trainer = GanClsTrainer(
                sess=sess,
                model=model,
                dataset=dataset,
                cfg=cfg,
            )
            trainer.train()
        else:
            model = GanCls(cfg, build_model=False)
            # The visualizer takes the configuration under the `config`
            # keyword, unlike the evaluator and the trainer.
            visualizer = GanClsVisualizer(
                sess=sess,
                model=model,
                dataset=dataset,
                config=cfg,
            )
            visualizer.visualize()
Esempio n. 4
0
def main(_):
    """Entry point: build the conditional GAN and optionally train it.

    Loads the YAML configuration named by ``FLAGS.cfg``, makes sure the
    checkpoint, sample and log directories exist, loads the test and train
    splits of the text dataset, builds ``ConditionalGan`` inside a TensorFlow
    session and lists its variables. Training via ``ConditionalGanTrainer``
    runs only when ``cfg.TRAIN.FLAG`` is truthy.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    pp.pprint(flags.FLAGS.__flags)

    cfg = config_from_yaml(FLAGS.cfg)

    # Create any output directory that does not exist yet.
    for dir_key in ('CHECKPOINT_DIR', 'SAMPLE_DIR', 'LOGS_DIR'):
        if not os.path.exists(getattr(cfg, dir_key)):
            os.makedirs(getattr(cfg, dir_key))

    # Ask TensorFlow to grow its GPU memory allocation as needed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    data_root = cfg.DATASET_DIR
    # NOTE(review): the meaning of the second argument (1) is not visible
    # here; confirm against TextDataset.
    dataset = TextDataset(data_root, 1)

    # NOTE(review): the test split is stored on `_test` while the train split
    # goes to `train`; confirm this asymmetry against TextDataset.
    dataset._test = dataset.get_data('%s/test' % data_root)
    dataset.train = dataset.get_data('%s/train' % data_root)

    with tf.Session(config=session_config) as sess:
        # The model is built and its variables listed whether or not
        # training runs afterwards.
        model = ConditionalGan(cfg)
        show_all_variables()

        # There is no non-training mode: a falsy TRAIN.FLAG stops here.
        if cfg.TRAIN.FLAG:
            trainer = ConditionalGanTrainer(
                sess=sess,
                model=model,
                dataset=dataset,
                cfg=cfg,
            )
            trainer.train()
Esempio n. 5
0
            logs_dir = os.path.join(cfg.LOGS_DIR, 'stage%d/' % stage[i])

        if not os.path.exists(pggan_checkpoint_dir_write):
            os.makedirs(pggan_checkpoint_dir_write)
        if not os.path.exists(sample_path):
            os.makedirs(sample_path)
        if not os.path.exists(logs_dir):
            os.makedirs(logs_dir)
        if not os.path.exists(pggan_checkpoint_dir_read):
            os.makedirs(pggan_checkpoint_dir_read)

        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True

        datadir = cfg.DATASET_DIR
        dataset = TextDataset(datadir, cfg.MODEL.SIZES[stage[i] - 1])

        filename_test = '%s/test' % datadir
        dataset.test = dataset.get_data(filename_test)

        filename_train = '%s/train' % datadir
        dataset.train = dataset.get_data(filename_train)

        pggan = PGGAN(cfg,
                      batch_size=batch_size,
                      steps=max_iters,
                      check_dir_write=pggan_checkpoint_dir_write,
                      check_dir_read=pggan_checkpoint_dir_read,
                      dataset=dataset,
                      sample_path=sample_path,
                      log_dir=logs_dir,