Code example #1
def main(config):
    num_step = config.num_step
    data_loader = LSUNCatLoader(get_lsun_patterns(config.data_dir),
                                num_workers=4,
                                actions=lsun_process_actions())

    names = []
    fobjs = []
    try:
        data_loader.start_fetch()  # launch the background fetch workers
        print("generating images...")
        # Write one generated batch per step to a temporary .npy file; the
        # files are scored together below and removed in the finally block.
        for _ in xrange(num_step):
            fd, name = tempfile.mkstemp(suffix=".npy")
            fobj = os.fdopen(fd, "wb+")
            names.append(name)
            fobjs.append(fobj)
            image_arr = data_loader.next_batch(config.batch_size)[0]
            np.save(fobj, image_arr, allow_pickle=False)
            fobj.close()

        mean_score, std_score = get_resnet18_score(
            images_iter(names),
            config.model_path,
            batch_size=100,
            split=10)

        print("mean = %.4f, std = %.4f." % (mean_score, std_score))

        if config.save_path is not None:
            with open(config.save_path, "wb") as f:
                cPickle.dump(dict(batch_size=config.batch_size,
                                  scores=dict(mean=mean_score, std=std_score)), f)
    finally:
        data_loader.stop_fetch()  # shut down the fetch workers
        for name in names:
            os.unlink(name)       # remove the temporary .npy files
        for fobj in fobjs:
            fobj.close()          # no-op for files already closed in the loop
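
The lifecycle that example #1 relies on is the same in every snippet on this page: build the loader from get_lsun_patterns(), call start_fetch() before reading, pull batches with next_batch(), and always call stop_fetch() in a finally block. Below is a minimal sketch of that pattern, assuming only the LSUNCatLoader API exactly as it is used above; the helper name generate_batches is made up for illustration.

# Minimal sketch of the loader lifecycle used throughout these examples.
# LSUNCatLoader, get_lsun_patterns and lsun_process_actions come from the
# project's own data utilities; generate_batches is a hypothetical helper.
def generate_batches(data_dir, num_step, batch_size):
    data_loader = LSUNCatLoader(get_lsun_patterns(data_dir),
                                num_workers=4,
                                actions=lsun_process_actions())
    batches = []
    try:
        data_loader.start_fetch()    # start the background worker processes
        for _ in range(num_step):
            # next_batch() returns a tuple; the first element is the image array
            batches.append(data_loader.next_batch(batch_size)[0])
    finally:
        data_loader.stop_fetch()     # always shut the workers down
    return batches
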
Code example #2
        eval_losses.append(eval_loss)

    sess.close()
    print("accuracy:", np.mean(eval_losses))  # mean over all evaluation batches


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("val_data_dir", metavar="VALDATADIR")
    parser.add_argument("model_path", metavar="MODELPATH")
    parser.add_argument("--batch-size",
                        dest="batch_size",
                        type=int,
                        default=100)
    parser.add_argument("--dim", dest="dim", default=64, type=int)

    config = parser.parse_args()

    print("config: %r" % config)

    eval_data_loader = LSUNCatLoader(get_lsun_patterns(config.val_data_dir),
                                     num_workers=2,
                                     actions=lsun_process_actions())

    try:
        eval_data_loader.start_fetch()
        run_task(config, eval_data_loader, classifier_forward,
                 tf.train.AdamOptimizer())
    finally:
        eval_data_loader.stop_fetch()
Code example #3
    if config.save_path is not None:
        fobj = open(config.save_path, "w")
    else:
        fobj = None

    for params in chain(GRID1, GRID2):
        for key, value in params.items():
            setattr(config, key, value)
        name = NAME_STYLE % params
        print("config: %r" % config)
        print("resetting environment...")
        tf.reset_default_graph()

        eval_data_loader = LSUNCatLoader(
            get_lsun_patterns(config.eval_data_dir),
            num_workers=5,
            block_size=20,
            actions=lsun_process_actions())
        try:
            eval_data_loader.start_fetch()
            mean_accuracy = run_task_eval(config,
                                          eval_data_loader,
                                          image_classifier_forward,
                                          model_dir=os.path.join(
                                              model_dir, name + "_models"))
            if fobj is not None:
                fobj.write("%s: %.4f\n" % (name, mean_accuracy))
                print("f**k")

        finally:
            eval_data_loader.stop_fetch()
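
GRID1, GRID2 and NAME_STYLE in example #3 are defined elsewhere in the project. From the way they are used here (chained with itertools.chain(), overlaid onto the argparse namespace with setattr(), and formatted with NAME_STYLE % params), they are presumably lists of config dicts plus a %-style format string keyed on those dicts. A hypothetical shape, for illustration only:

from itertools import chain

# Hypothetical grid definitions; the real values live in the project source.
GRID1 = [dict(learning_rate=lr, dim=64) for lr in (1e-4, 2e-4)]
GRID2 = [dict(learning_rate=lr, dim=128) for lr in (2e-4,)]
NAME_STYLE = "lr_%(learning_rate)g_dim_%(dim)d"

for params in chain(GRID1, GRID2):
    print(NAME_STYLE % params)  # e.g. "lr_0.0001_dim_64"

tf.reset_default_graph() is called between grid points because each run builds a fresh TensorFlow 1.x graph; without it, variables from earlier runs would accumulate in the default graph.
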
Code example #4
File: dp_lsun_5cat.py  Project: sunnerzs/dpgan-1
    parser.add_argument("--sample-dir", dest="sample_dir")

    config = parser.parse_args()

    np.random.seed()
    if config.enable_accounting:
        # Recalibrate the noise scale sigma from the privacy parameters
        # (epsilon, delta); see the note after this example.
        config.sigma = np.sqrt(
            2.0 * np.log(1.25 / config.delta)) / config.epsilon
        print("Now with new sigma: %.4f" % config.sigma)

    if config.image_size == 64:
        patterns = get_lsun_patterns(config.data_dir)
        print(patterns)
        data_loader = LSUNCatLoader(patterns,
                                    num_workers=4,
                                    actions=lsun_process_actions(),
                                    block_size=16,
                                    max_blocks=256)
        data_loader.start_fetch()
        generator_forward = d64_resnet_dcgan.generator_forward
        discriminator_forward = d64_resnet_dcgan.discriminator_forward
    else:
        raise NotImplementedError("Unsupported image size %d." %
                                  config.image_size)

    if config.enable_accounting:
        accountant = GaussianMomentsAccountant(data_loader.num_steps(1),
                                               config.moment)
        if config.log_path:
            open(config.log_path, "w").close()  # truncate any existing log file
    else:
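
The sigma computed in example #4 is the standard Gaussian-mechanism noise scale for (epsilon, delta)-differential privacy with L2 sensitivity 1: sigma = sqrt(2 * ln(1.25 / delta)) / epsilon. A quick numeric check with made-up privacy parameters:

import numpy as np

def gaussian_sigma(epsilon, delta):
    # Same expression as in the snippet above; sensitivity is assumed to be 1.
    return np.sqrt(2.0 * np.log(1.25 / delta)) / epsilon

print("%.4f" % gaussian_sigma(1.0, 1e-5))  # ~4.8448 for these example values
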
Code example #5
        name = NAME_STYLE % params
        if save_dir is not None:
            config.save_dir = os.path.join(save_dir, name + "_models")
            os.makedirs(config.save_dir, exist_ok=True)
        if log_dir is not None:
            config.log_path = os.path.join(log_dir, name + ".log")

        print("config: %r" % config)
        print("resetting environment...")
        tf.reset_default_graph()

        train_data_loader = LSUNCatLoader(
            get_lsun_patterns(config.train_data_dir),
            num_workers=10,
            block_size=20,
            max_blocks=500,
            max_numbers=None,
            actions=lsun_process_actions(),
            public_num=config.public_num,
            public_seed=1024)
        eval_data_loader = LSUNCatLoader(
            get_lsun_patterns(config.eval_data_dir),
            block_size=20,
            max_numbers=None,
            actions=lsun_process_actions(),
            num_workers=4)

        try:
            train_data_loader.start_fetch()
            eval_data_loader.start_fetch()