Ejemplo n.º 1
0
def generate_imagenet(train_dir):
    """Sample fake ImageNet images from a trained GAN checkpoint and write
    them as sharded TFRecord files under DATASETS/<dataset>.

    Args:
        train_dir: directory holding the GAN checkpoint to restore.
    """
    # Generation is capped at 128 images per sess.run to bound memory;
    # each output shard is assembled from several of these sub-batches.
    real_bs = 128
    subbatches_per_shard = args.eval_batch_size // real_bs
    images, labels = init_generator(real_bs, denormalize=True)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    tf.get_default_graph().finalize()

    split_name = 'gan_100_' + args.run_name
    tfrecord_root = os.path.join(DATASETS, args.dataset)

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(init_assign_op, feed_dict=init_feed_dict)

        total_shards = args.num_generated_batches
        for shard_idx in range(total_shards):
            shard_path = os.path.join(
                tfrecord_root,
                '%s-%.5d-of-%.5d' % (split_name, shard_idx, total_shards))

            def _sample(_unused):
                # The generator graph takes no input here; just run it.
                return sess.run([images, labels])

            shard_batch = split_apply_concat(
                np.zeros(args.eval_batch_size),
                _sample, subbatches_per_shard, num_outputs=2)
            imagenet.convert_to_tfrecord(shard_batch, shard_path)
Ejemplo n.º 2
0
def generate(train_dir, suffix=""):
    """Sample fake images/labels from a trained GAN checkpoint and save them
    as .npy arrays (X_... and, for conditional models, Y_...) under DATASETS.

    Args:
        train_dir: directory holding the GAN checkpoint to restore.
        suffix: optional suffix appended to the run name in the output files.
    """
    log.info("Generating %i batches using suffix %s", args.num_generated_batches, suffix)
    images, labels = init_generator(args.eval_batch_size, denormalize=True)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    tf.get_default_graph().finalize()

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(init_assign_op, feed_dict=init_feed_dict)

        def _sample(_unused):
            return sess.run([images, labels])

        total_samples = args.num_generated_batches * args.eval_batch_size
        data, gt = split_apply_concat(
            np.zeros(total_samples),
            _sample, args.num_generated_batches, num_outputs=2)

    if args.total_class_splits > 0:
        # Class-split training mode: the generator was trained on a
        # contiguous slice of classes, so shift labels back into the
        # global label space.
        assert args.num_classes % args.total_class_splits == 0
        per_split = args.num_classes // args.total_class_splits
        label_offset = args.active_split_num * per_split
        log.info("Split training mode for generation: adding %i to all labels", label_offset)
        gt += label_offset

    assert len(data) == len(gt)
    assert len(gt) == args.num_generated_batches * args.eval_batch_size
    data = data.astype(np.uint8)

    np.save(os.path.join(DATASETS, args.dataset, "X_gan_100_%s.npy" % (args.run_name+suffix)), data)
    if not args.unconditional:
        np.save(os.path.join(DATASETS, args.dataset, "Y_gan_100_%s.npy" % (args.run_name+suffix)), gt)
Ejemplo n.º 3
0
def build_predictor(train_dir, cfg, images):
    """Build a resnet classifier over `images` in PREDICT mode.

    Args:
        train_dir: directory holding the classifier checkpoint.
        cfg: experiment config with resnet/evaluation/dataset settings.
        images: image tensor fed to the network.

    Returns:
        (model, init_classifier_fn): the built model and a function that,
        given a session, restores the checkpoint weights into it.
    """
    log.info("Loading classifier network from the following config: %s",
             str(cfg))
    params = dict(
        resnet_size=cfg.resnet_size,
        data_format='channels_first',
        batch_size=cfg.evaluation.batch_size,
        num_classes=cfg.dataset.num_classes,
    )

    with tf.variable_scope("resnet", reuse=None):
        model = cfg.network_fn(images, None, tf.estimator.ModeKeys.PREDICT,
                               params)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)

    def init_classifier_fn(sess):
        # Deferred restore so the caller controls session creation.
        sess.run(init_assign_op, feed_dict=init_feed_dict)

    return model, init_classifier_fn
Ejemplo n.º 4
0
def swd(train_dir):
    """Compute Sliced Wasserstein Distance scores between 8192 real images
    and 8192 generated images from the checkpoint in `train_dir`, logging
    both the raw and the conventionally scaled (x10^3) scores.
    """
    bs = 8192

    fake_images, _ = init_generator(bs//32, denormalize=True)
    real_images, _ = cifar.load_cifar("train", bs, normalize=False,
                                      return_numpy=True, dataset=args.dataset)
    real_images = real_images[:bs]
    # Renamed from `swd`: the original local variable shadowed this
    # function's own name.
    scorer = Sliced_Wasserstein_Scorer(32, 16, 32)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    tf.get_default_graph().finalize()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        sess.run(init_assign_op, feed_dict=init_feed_dict)

        def _generate_images(x):
            return sess.run(fake_images)
        # Generate bs fakes in 32 sub-batches of bs//32 images each.
        fake_images_batch = split_apply_concat(np.arange(bs), _generate_images, 32)
        swd_scores = scorer.calc_sliced_wasserstein_scores(real_images, fake_images_batch)
        # Lazy %-args instead of eager string formatting.
        log.info("SWD scores: %s", swd_scores)
        scaled_swd = 10**3 * np.array(swd_scores)
        log.info("SWD scores * 10^3: %s", str(scaled_swd))
Ejemplo n.º 5
0
def evaluate_classifier(train_dir, cfg):
    """Evaluate a trained resnet classifier on cfg.evaluation.split.

    Restores weights from `train_dir`, streams the whole evaluation split
    through the running accuracy metric, and returns the final accuracy.

    Args:
        train_dir: directory holding the classifier checkpoint.
        cfg: experiment config with dataset/evaluation/network settings.

    Returns:
        Final streaming accuracy over the evaluation split.
    """
    log.info("Loading classifier network from the following config: %s",
             str(cfg))
    model_params = {
        'resnet_size': cfg.resnet_size,
        'data_format': 'channels_first',
        'batch_size': cfg.evaluation.batch_size,
        'num_classes': cfg.dataset.num_classes,
    }
    # Fix: load the evaluation split with the *evaluation* batch size.
    # The original used cfg.training.batch_size here, which disagreed with
    # both model_params['batch_size'] above and the num_test_batches
    # computation below (both keyed on cfg.evaluation.batch_size).
    images, labels, iter_fn = matcher.load_dataset(cfg.evaluation.split,
                                                   cfg.evaluation.batch_size,
                                                   cfg.dataset.name,
                                                   cfg.dataset.image_size,
                                                   augmentation=False,
                                                   shuffle=False,
                                                   normalize=True,
                                                   onehot=True)

    with tf.variable_scope("resnet", reuse=None):
        model = cfg.network_fn(images, labels, tf.estimator.ModeKeys.EVAL,
                               model_params)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    local_init_op = tf.local_variables_initializer()
    tf.get_default_graph().finalize()

    log.info("Creating the session...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)
        sess.run(local_init_op)  # reset counters in metrics
        num_test_batches = cfg.evaluation.test_set_size // cfg.evaluation.batch_size
        for i in range(num_test_batches):
            # The streaming metric accumulates across batches; the value
            # read on the last iteration is the final split accuracy.
            eval_acc = sess.run(model.eval_metric_ops)['accuracy'][1]
        log.info("Final evaluation accuracy on split %s: %.4f",
                 cfg.evaluation.split, eval_acc)
    return eval_acc
Ejemplo n.º 6
0
def inception_score(train_dir):
    """Compute the Inception score and FID of a generated-image dataset
    split against precomputed real-image Inception activations, and log
    both results.

    Args:
        train_dir: directory holding the checkpoint to restore.
    """
    # Evaluate either the run's own generated split or an explicit override.
    split = (args.inception_file if args.inception_file != ""
             else "gan_100_%s" % args.run_name)
    images, _, iter_fn = matcher.load_dataset(split, 100,
                                              args.dataset, args.image_size,
                                              augmentation=False, shuffle=True,
                                              classes=None, normalize=False)

    inception = Inception(images)
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    tf.get_default_graph().finalize()

    real_activations = matcher.load_inception_activations(args.dataset, args.image_size)

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)
        scores, fids = inception.compute_inception_score_and_fid(
            real_activations, sess, splits=args.inception_splits)
        is_mean, is_std = scores
        fid_mean, fid_std = fids
        log.info("Final Inception score over 50K images: %f +- %f" % (is_mean, is_std))
        log.info("Final Frechet Inception distance over 50K images of %s: %f +- %f" % (split, fid_mean, fid_std))
Ejemplo n.º 7
0
def train_classifier(train_dir, cfg):
    """Train a resnet classifier per `cfg`, checkpointing into `train_dir`.

    Builds a training model and a weight-sharing evaluation model, restores
    any existing checkpoint from `train_dir` (resuming training), then runs
    the training loop with periodic logging, summaries, checkpoints, and
    intermediate evaluations.

    Args:
        train_dir: directory for checkpoints and TensorBoard summaries; an
            existing checkpoint here is restored to resume training.
        cfg: experiment config with dataset/training/evaluation/network
            settings.

    Returns:
        Final streaming accuracy on cfg.evaluation.split.
    """
    log.info("Training classifier network from the following config: %s",
             str(cfg))
    # The evaluation loop below assumes the test set divides evenly.
    assert cfg.evaluation.test_set_size % cfg.evaluation.batch_size == 0
    num_test_batches = cfg.evaluation.test_set_size // cfg.evaluation.batch_size

    model_params = {
        'resnet_size': cfg.resnet_size,
        'data_format': 'channels_first',
        'batch_size': cfg.training.batch_size,
        'num_classes': cfg.dataset.num_classes,
    }

    log.info("Creating the graph...")
    # Training pipeline: augmented, shuffled, normalized, one-hot labels.
    images, labels, iter_fn = matcher.load_dataset(cfg.training.split,
                                                   cfg.training.batch_size,
                                                   cfg.dataset.name,
                                                   cfg.dataset.image_size,
                                                   augmentation=True,
                                                   shuffle=True,
                                                   normalize=True,
                                                   onehot=True)
    # Evaluation pipeline: no augmentation/dequantization, own batch size.
    val_images, val_labels, val_iter_fn = matcher.load_dataset(
        cfg.evaluation.split,
        cfg.evaluation.batch_size,
        cfg.dataset.name,
        cfg.dataset.image_size,
        augmentation=False,
        shuffle=True,
        normalize=True,
        onehot=True,
        dequantize=False)

    with tf.variable_scope("resnet", reuse=None):
        model = cfg.network_fn(images, labels, tf.estimator.ModeKeys.TRAIN,
                               model_params)

    # reuse=True: the eval model shares weights with the training model.
    with tf.variable_scope("resnet", reuse=True):
        val_model = cfg.network_fn(val_images, val_labels,
                                   tf.estimator.ModeKeys.EVAL, model_params)

    local_init_op = tf.local_variables_initializer()
    clean_init_op = tf.group(tf.global_variables_initializer(), local_init_op)
    global_step = tf.train.get_or_create_global_step()
    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    saver = tf.train.Saver(max_to_keep=1)

    summary_op = tf.summary.merge_all()
    # Freeze the graph so nothing is accidentally added during training.
    tf.get_default_graph().finalize()

    log.info("Creating the session...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        summary_writer = tf.summary.FileWriter(train_dir)
        # Fresh init first, then overwrite with the checkpoint (if any),
        # so a resumed run continues from its saved weights and step.
        sess.run(clean_init_op)
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)
        val_iter_fn(sess)

        starting_step = sess.run(global_step)
        starting_time = time.time()

        log.info("Starting training from step %i..." % starting_step)
        for step in range(starting_step, cfg.training.max_iterations + 1):
            start_time = time.time()
            try:
                _, train_loss = sess.run([model.train_op, model.loss])
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # Input pipeline exhausted or session cancelled: stop cleanly.
                break
            except KeyboardInterrupt:
                log.info("Killed by ^C")
                break

            if step % cfg.training.print_step == 0:
                duration = float(time.time() - start_time)
                examples_per_sec = cfg.training.batch_size / duration
                avg_speed = (time.time() - starting_time) / (step -
                                                             starting_step + 1)
                time_to_finish = datetime.timedelta(
                    seconds=(avg_speed * (cfg.training.max_iterations - step)))
                end_date = datetime.datetime.now() + time_to_finish
                format_str = (
                    'step %d, %.3f (%.1f examples/sec; %.3f sec/batch)')
                log.info(format_str %
                         (step, train_loss, examples_per_sec, duration))
                log.info(
                    "%i iterations left expected to finish after %s, thus at %s (avg speed: %.3f sec/batch)"
                    % (cfg.training.max_iterations - step, str(time_to_finish),
                       end_date.strftime("%Y-%m-%d %H:%M:%S"), avg_speed))

            if step % cfg.training.summary_step == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % cfg.training.ckpt_step == 0 and step > 0:
                summary_writer.flush()
                log.debug("Saving checkpoint...")
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver.save(sess,
                           checkpoint_path,
                           global_step=step,
                           write_meta_graph=False)

            if step % cfg.training.eval_step == 0:
                sess.run(local_init_op)  # reset counters in metrics
                # Streaming accuracy accumulates over the batches; the last
                # read yields the accuracy over the whole test set.
                for _ in range(num_test_batches):
                    eval_acc = sess.run(
                        val_model.eval_metric_ops)['accuracy'][1]
                log.info("Intermediate evaluation accuracy: %.4f" % eval_acc)
                summary = tf.Summary()
                summary.value.add(tag='accuracy/test', simple_value=eval_acc)
                summary_writer.add_summary(summary, step)

        summary_writer.close()

        # Final evaluation pass over the full test set.
        sess.run(local_init_op)  # reset counters in metrics
        for i in range(num_test_batches):
            eval_metrics = sess.run(val_model.eval_metric_ops)
            eval_acc = eval_metrics['accuracy'][1]
            eval_acc_top5 = eval_metrics['top5_accuracy'][1]
        log.info("Final evaluation accuracy on split %s: %.4f",
                 cfg.evaluation.split, eval_acc)
        if cfg.dataset.name == 'imagenet':
            log.info("Final evaluation top-5 accuracy on split %s: %.4f",
                     cfg.evaluation.split, eval_acc_top5)
        return eval_acc
Ejemplo n.º 8
0
def train(train_dir):
    """Train a (possibly conditional) GAN with tf.contrib.gan (tfgan).

    Supports three conditioning modes selected by global flags: fully
    unconditional, projection-conditioned, and AC-GAN.  Optionally trains
    on a contiguous slice of classes ("class split" mode), and can compute
    Inception score / FID periodically during training.

    Args:
        train_dir: directory for checkpoints and TensorBoard summaries; an
            existing checkpoint here is restored to resume training.
    """
    # XXX subsampling support is dropped silently
    assert abs(args.subsampling - 1) < 0.01
    target_classes = list(range(args.num_classes))
    # XXX classes support is dropped, UPDATE: retrofitted for class splits
    classes = None
    split_training_mode = args.total_class_splits > 0
    num_classes = args.num_classes
    if split_training_mode:
        # Train only on one contiguous slice of the label space.
        assert args.num_classes % args.total_class_splits == 0
        split_sz = args.num_classes // args.total_class_splits
        classes_a = args.active_split_num * split_sz
        classes_b = classes_a + split_sz
        classes = list(range(classes_a, classes_b))
        num_classes = split_sz
        log.info("Class split training mode is activated: "
                 "this run chooses %i split out of %i in total, "
                 "thus classes=%s", args.active_split_num,
                 args.total_class_splits, str(classes))
    images, labels, iter_fn = matcher.load_dataset(args.train_split, args.batch_size,
                                                   args.dataset, args.image_size,
                                                   augmentation=False, shuffle=True,
                                                   classes=classes, normalize=True)
    if split_training_mode:
        # Remap labels to [0, split_sz) for the split-local model.
        labels -= classes_a

    noise = tf.random_normal([args.batch_size, args.noise_dims])
    discriminator_train_steps = args.num_discriminator_steps
    generator_train_steps = 1

    def conditional_generator(x):
        # Generator wrapper; keeps only the images from the arch's output.
        return archs[args.arch].generator(x, True, num_classes=num_classes)[0]

    def conditional_discriminator(x, conditioning):
        # Returns (gan_logits, class_logits) as required by acgan_model.
        gan_logits, class_logits, _ = archs[args.arch].discriminator(x, True,
                                                                     gen_input=conditioning,
                                                                     num_classes=num_classes)
        return gan_logits, class_logits

    def unconditional_discriminator(x, conditioning):
        # Returns only gan_logits, as required by gan_model.
        gan_logits, _, _ = archs[args.arch].discriminator(x, True,
                                                          gen_input=conditioning,
                                                          num_classes=num_classes)
        return gan_logits

    one_hot_labels = tf.one_hot(labels, num_classes)
    if args.unconditional:
        # Plain GAN: noise in, no label conditioning anywhere.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,
            discriminator_fn=unconditional_discriminator,
            real_data=images,
            generator_inputs=noise)
    elif args.projection:
        # Projection conditioning: labels fed via generator inputs only.
        gan_model = tfgan.gan_model(
            generator_fn=conditional_generator,
            discriminator_fn=unconditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels))
    else:
        # AC-GAN: discriminator also classifies, weighted below.
        gan_model = tfgan.acgan_model(
            generator_fn=conditional_generator,
            discriminator_fn=conditional_discriminator,
            real_data=images,
            generator_inputs=(noise, one_hot_labels),
            one_hot_labels=one_hot_labels)

    # Near-zero flag values mean "disabled"; tfgan skips None weights.
    gp = None if abs(args.gradient_penalty) < 0.01 else args.gradient_penalty
    acgan_gw = None if (args.unconditional or abs(args.acgan_gw) < 0.001) else args.acgan_gw
    acgan_dw = None if (args.unconditional or abs(args.acgan_dw) < 0.001) else args.acgan_dw

    if args.gan_loss == 'hinge':
        model_gen_loss = gan_losses.hinge_generator_loss
        model_dis_loss = gan_losses.hinge_discriminator_loss
    elif args.gan_loss == 'wasserstein':
        model_gen_loss = tfgan.losses.wasserstein_generator_loss
        model_dis_loss = tfgan.losses.wasserstein_discriminator_loss
    elif args.gan_loss == 'classical':
        model_gen_loss = tfgan.losses.modified_generator_loss
        model_dis_loss = tfgan.losses.modified_discriminator_loss
    else:
        raise ValueError("Unsupported GAN loss")

    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=model_gen_loss,
        discriminator_loss_fn=model_dis_loss,
        aux_cond_generator_weight=acgan_gw,
        aux_cond_discriminator_weight=acgan_dw,
        gradient_penalty_weight=gp,
    )

    global_step = tf.train.get_or_create_global_step()
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=get_optimizer("generator"),
        discriminator_optimizer=get_optimizer("discriminator"))

    if args.inception_step > 0:
        # Build the eval generator (reuse=True shares trained weights) for
        # periodic Inception score / FID measurements.
        real_activations = matcher.load_inception_activations(args.dataset, args.image_size)
        inception = Inception(init_generator(args.eval_batch_size, reuse=True, denormalize=True)[0])

    tfgan.eval.add_gan_model_image_summaries(gan_model, grid_size=int(np.sqrt(args.batch_size)))

    init_assign_op, init_feed_dict = utils.restore_ckpt(train_dir, log)
    summary_op = tf.summary.merge_all()
    clean_init_op = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=1)
    # Freeze the graph so nothing is accidentally added during training.
    tf.get_default_graph().finalize()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        summary_writer = tf.summary.FileWriter(train_dir)
        # Fresh init first, then overwrite with the checkpoint (if any),
        # so a resumed run continues from its saved weights and step.
        sess.run(clean_init_op)
        sess.run(init_assign_op, feed_dict=init_feed_dict)
        iter_fn(sess)

        starting_step = sess.run(global_step)
        starting_time = time.time()
        log.info("Starting training from step %i..." % starting_step)
        for step in range(starting_step, args.max_iterations+1):
            start_time = time.time()
            try:
                # Alternate generator and discriminator updates; the global
                # step is incremented once per outer iteration.
                gen_loss = 0
                for _ in range(generator_train_steps):
                    cur_gen_loss = sess.run(train_ops.generator_train_op)
                    gen_loss += cur_gen_loss

                dis_loss = 0
                for _ in range(discriminator_train_steps):
                    cur_dis_loss = sess.run(train_ops.discriminator_train_op)
                    dis_loss += cur_dis_loss

                sess.run(train_ops.global_step_inc_op)
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                # Input pipeline exhausted or session cancelled: stop cleanly.
                break
            except KeyboardInterrupt:
                log.info("Killed by ^C")
                break

            if step % args.print_step == 0:
                duration = float(time.time() - start_time)
                examples_per_sec = args.batch_size / duration
                log.info("step %i: gen loss = %f, dis loss = %f (%.1f examples/sec; %.3f sec/batch)"
                         % (step, gen_loss, dis_loss, examples_per_sec, duration))
                avg_speed = (time.time() - starting_time)/(step - starting_step + 1)
                time_to_finish = avg_speed * (args.max_iterations - step)
                end_date = datetime.datetime.now() + datetime.timedelta(seconds=time_to_finish)
                log.info("%i iterations left expected to finish at %s (avg speed: %.3f sec/batch)"
                        % (args.max_iterations - step, end_date.strftime("%Y-%m-%d %H:%M:%S"), avg_speed))

            if step % args.summary_step == 0:
                # Merge the graph summaries with the (optional) scalar
                # Inception/FID values into a single summary record.
                summary = tf.Summary()
                summary.ParseFromString(sess.run(summary_op))
                if args.inception_step != 0 and step % args.inception_step == 0 and step > 0:
                    scores, fids = inception.compute_inception_score_and_fid(real_activations, sess)
                    is_mean, is_std = scores
                    fid_mean, fid_std = fids
                    summary.value.add(tag='Inception_50K_mean', simple_value=is_mean)
                    summary.value.add(tag='Inception_50K_std', simple_value=is_std)
                    summary.value.add(tag='FID_50K_mean', simple_value=fid_mean)
                    summary.value.add(tag='FID_50K_std', simple_value=fid_std)
                    log.info("Inception score over 50K images: %f +- %f" % (is_mean, is_std))
                    log.info("Frechet Inception distance over 50K images: %f +- %f" % (fid_mean, fid_std))
                summary_writer.add_summary(summary, step)

            if step % args.ckpt_step == 0 and step >= 0:
                summary_writer.flush()
                log.debug("Saving checkpoint...")
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=False)

        summary_writer.close()