Example #1
import os
import pickle
import tempfile

import numpy as np

# LSUNCatLoader, get_lsun_patterns, lsun_process_actions, images_iter and
# get_resnet18_score are project-local helpers from the surrounding repo.


def main(config):
    num_step = config.num_step
    data_loader = LSUNCatLoader(get_lsun_patterns(config.data_dir),
                                num_workers=4,
                                actions=lsun_process_actions())

    names = []
    fobjs = []
    try:
        data_loader.start_fetch()
        print("generating images...")
        for _ in range(num_step):
            fd, name = tempfile.mkstemp(suffix=".npy")
            fobj = os.fdopen(fd, "wb+")
            names.append(name)
            fobjs.append(fobj)
            image_arr = data_loader.next_batch(config.batch_size)[0]
            np.save(fobj, image_arr, allow_pickle=False)
            fobj.close()

        mean_score, std_score = get_resnet18_score(images_iter(names),
                                                   config.model_path,
                                                   batch_size=100,
                                                   split=10)

        print("mean = %.4f, std = %.4f." % (mean_score, std_score))

        if config.save_path is not None:
            with open(config.save_path, "wb") as f:
                pickle.dump(dict(batch_size=config.batch_size,
                                 scores=dict(mean=mean_score, std=std_score)), f)
    finally:
        data_loader.stop_fetch()
        for fobj in fobjs:
            fobj.close()  # no-op for handles already closed in the loop
        for name in names:
            os.unlink(name)  # remove the temporary .npy files
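
The `images_iter` helper handed to `get_resnet18_score` is not shown in this example. A minimal sketch of what it plausibly does, assuming it simply streams the saved .npy batches back as a flat iterator of images; the reconstruction is hypothetical:

def images_iter(names):
    # Hypothetical reconstruction: replay the temporary .npy files written
    # above, yielding one image at a time.
    for name in names:
        batch = np.load(name)  # shape: [batch_size, 64, 64, 3]
        for image in batch:
            yield image
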
Example #2
        eval_losses.append(eval_loss)

    sess.close()
    print("accuracy:", np.mean(eval_losses))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("val_data_dir", metavar="VALDATADIR")
    parser.add_argument("model_path", metavar="MODELPATH")
    parser.add_argument("--batch-size",
                        dest="batch_size",
                        type=int,
                        default=100)
    parser.add_argument("--dim", dest="dim", default=64, type=int)

    config = parser.parse_args()

    print("config: %r" % config)

    eval_data_loader = LSUNCatLoader(get_lsun_patterns(config.val_data_dir),
                                     num_workers=2,
                                     actions=lsun_process_actions())

    try:
        eval_data_loader.start_fetch()
        run_task(config, eval_data_loader, classifier_forward,
                 tf.train.AdamOptimizer())
    finally:
        eval_data_loader.stop_fetch()
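
This example begins mid-loop, so the code that fills `eval_losses` is cut off. A minimal sketch of the implied evaluation loop, assuming `run_task` builds a `loss` tensor over `input_images`/`input_labels` placeholders; every name below is hypothetical:

eval_losses = []
for _ in range(num_eval_steps):  # num_eval_steps: hypothetical batch count
    images, labels = eval_data_loader.next_batch(config.batch_size)
    eval_loss = sess.run(loss, feed_dict={input_images: images,
                                          input_labels: labels})
    eval_losses.append(eval_loss)
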
Example #3
import os
import pickle
import tempfile

import numpy as np
import tensorflow as tf
import tflearn
from tqdm import trange

# classifier_forward, LSUNLoader, lsun_process_actions, images_iter and
# get_quality_score are project-local helpers from the surrounding repo.


def main(config):
    config.dim = 64  # hard-code the feature width expected by classifier_forward
    num_step = config.num_step
    input_labels = tf.placeholder(tf.float32, [None, 1], name="input_labels")
    input_images = tf.placeholder(tf.float32, [None, 64, 64, 3],
                                  name="input_images")
    logits = classifier_forward(config, input_images, name="classifier")
    probs = tf.nn.sigmoid(logits)

    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=input_labels,
                                                logits=logits))
    classifier_vars = [
        var for var in tf.global_variables()
        if var.name.startswith("classifier")
    ]
    # classifier_vars is already filtered; no need to filter again in var_list
    train_step = tf.train.AdamOptimizer().minimize(loss,
                                                   var_list=classifier_vars)

    names = []
    fobjs = []

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # train classifier
    data_loader = LSUNLoader(config.data_dir,
                             block_size=20,
                             num_workers=6,
                             actions=lsun_process_actions())
    fake_data_loader = LSUNLoader(config.data_dir,
                                  block_size=20,
                                  num_workers=6,
                                  actions=lsun_process_actions())
    try:
        tflearn.is_training(True, sess)
        fake_data_loader.start_fetch()
        data_loader.start_fetch()
        bar = trange(config.x)  # config.x: number of classifier training steps
        for _ in bar:
            real_images, _ = data_loader.next_batch(config.batch_size)
            real_labels = np.full([len(real_images), 1], 1, np.float32)

            generated_images, _ = fake_data_loader.next_batch(
                config.batch_size)
            generated_labels = np.full([len(generated_images), 1], 0,
                                       np.float32)

            images = np.concatenate([real_images, generated_images], axis=0)
            labels = np.concatenate([real_labels, generated_labels], axis=0)

            indices = np.random.permutation(len(images))
            images = images[indices]
            labels = labels[indices]

            loss_value, _ = sess.run([loss, train_step],
                                     feed_dict={
                                         input_labels: labels,
                                         input_images: images
                                     })
            bar.set_description("loss: %.4f" % loss_value)

    finally:
        data_loader.stop_fetch()
        fake_data_loader.stop_fetch()

    # fresh loader for the scoring pass
    fake_data_loader = LSUNLoader(config.data_dir,
                                  block_size=20,
                                  num_workers=6,
                                  actions=lsun_process_actions())
    try:
        fake_data_loader.start_fetch()
        tflearn.is_training(False, sess)
        for _ in trange(num_step):
            fd, name = tempfile.mkstemp(suffix=".npy")
            fobj = os.fdopen(fd, "wb+")
            names.append(name)
            fobjs.append(fobj)
            image_arr = fake_data_loader.next_batch(config.batch_size)[0]
            np.save(fobj, image_arr, allow_pickle=False)
            fobj.close()

        mean_score, std_score = get_quality_score(sess,
                                                  input_images,
                                                  probs,
                                                  images_iter(names),
                                                  batch_size=100,
                                                  split=10)

        print("mean = %.4f, std = %.4f." % (mean_score, std_score))

        if config.save_path is not None:
            with open(config.save_path, "wb") as f:
                pickle.dump(
                    dict(batch_size=config.batch_size,
                         scores=dict(mean=mean_score, std=std_score)), f)
    finally:
        fake_data_loader.stop_fetch()
        for fobj in fobjs:
            fobj.close()  # no-op for handles already closed in the loop
        for name in names:
            os.unlink(name)  # remove the temporary .npy files
        sess.close()
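
`classifier_forward` is imported from elsewhere in the repo and not shown. A minimal sketch of a compatible TF1-style binary classifier, assuming a plain strided-convolution stack that maps 64x64x3 images to a single logit; the layer sizes are illustrative, not the original architecture:

def classifier_forward(config, images, name="classifier"):
    # Hypothetical stand-in: strided convolutions -> global pool -> 1 logit.
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        h = images
        for filters in (config.dim, config.dim * 2, config.dim * 4):
            h = tf.layers.conv2d(h, filters, kernel_size=4, strides=2,
                                 padding="same", activation=tf.nn.relu)
        h = tf.reduce_mean(h, axis=[1, 2])  # global average pooling
        return tf.layers.dense(h, 1)  # one logit per image

Because everything lives under the `classifier` variable scope, the `var.name.startswith("classifier")` filter above would pick up exactly these weights.
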
Example #4
    parser.add_argument("--beta2", default=0.9, type=float, dest="beta2")
    parser.add_argument("data_dir", metavar="DATADIR")
    parser.add_argument("--image-size", default=64, type=int, dest="image_size")
    parser.add_argument("--adaptive-rate", dest="adaptive_rate", action="store_true")
    parser.add_argument("--no-noise", dest="no_noise", action="store_true")
    parser.add_argument("--sample-dir", dest="sample_dir")

    config = parser.parse_args()

    np.random.seed()  # seed from OS entropy for non-deterministic sampling
    if config.enable_accounting:
        config.sigma = np.sqrt(2.0 * np.log(1.25 / config.delta)) / config.epsilon
        print("Derived noise sigma: %.4f" % config.sigma)

    if config.image_size == 64:
        data_loader = LSUNLoader(config.data_dir, num_workers=4, actions=lsun_process_actions())
        data_loader.start_fetch()
        generator_forward = d64_resnet_dcgan.generator_forward
        discriminator_forward = d64_resnet_dcgan.discriminator_forward
    else:
        raise NotImplementedError("Unsupported image size %d." % config.image_size)

    if config.enable_accounting:
        accountant = GaussianMomentsAccountant(data_loader.num_steps(1), config.moment)
        if config.log_path:
            open(config.log_path, "w").close()  # truncate any existing log file
    else:
        accountant = None

    if config.adaptive_rate:
        lr = tf.placeholder(tf.float32, shape=())
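
The `config.sigma` assignment in this example is the standard analytic calibration of the Gaussian mechanism, sigma = sqrt(2 ln(1.25 / delta)) / epsilon. A quick numeric check with illustrative values:

import numpy as np

epsilon, delta = 4.0, 1e-5  # illustrative (epsilon, delta) privacy budget
sigma = np.sqrt(2.0 * np.log(1.25 / delta)) / epsilon
print("sigma = %.4f" % sigma)  # ~1.2112 for these values
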
Example #5
    parser.add_argument("--image-size", default=64, type=int, dest="image_size")
    parser.add_argument("--adaptive-rate", dest="adaptive_rate", action="store_true")
    parser.add_argument("--no-noise", dest="no_noise", action="store_true")
    parser.add_argument("--sample-dir", dest="sample_dir")

    config = parser.parse_args()

    np.random.seed()  # seed from OS entropy for non-deterministic sampling
    if config.enable_accounting:
        config.sigma = np.sqrt(2.0 * np.log(1.25 / config.delta)) / config.epsilon
        print("Derived noise sigma: %.4f" % config.sigma)

    if config.image_size == 64:
        patterns = get_lsun_patterns(config.data_dir)
        print(patterns)  # debug output: the resolved LSUN file patterns
        data_loader = LSUNCatLoader(patterns, num_workers=4, actions=lsun_process_actions(),
                                    block_size=16, max_blocks=256)
        data_loader.start_fetch()
        generator_forward = d64_resnet_dcgan.generator_forward
        discriminator_forward = d64_resnet_dcgan.discriminator_forward
    else:
        raise NotImplementedError("Unsupported image size %d." % config.image_size)

    if config.enable_accounting:
        accountant = GaussianMomentsAccountant(data_loader.num_steps(1), config.moment)
        if config.log_path:
            open(config.log_path, "w").close()  # truncate any existing log file
    else:
        accountant = None

    if config.adaptive_rate:
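
Example #5 is cut off at the `adaptive_rate` branch; judging from Example #4, it creates the scalar `lr` placeholder next. A minimal sketch of how such a placeholder is typically wired into the optimizer and fed per step, with an illustrative decay schedule; nothing below is from the original code:

# Hypothetical wiring for the adaptive-rate branch.
lr = tf.placeholder(tf.float32, shape=(), name="lr")
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)

current_lr = 2e-4  # illustrative initial rate
for step in range(num_steps):  # num_steps: hypothetical
    feed = next_batch_feeds()  # hypothetical image/label feeds
    feed[lr] = current_lr
    sess.run(train_op, feed_dict=feed)
    if step > 0 and step % 10000 == 0:
        current_lr *= 0.5  # illustrative halving schedule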