Example #1
import torch  # needed for torch.cuda.empty_cache() below; other helpers come from the surrounding project
def main():
    cfg = get_config()
    log_path = f"./logs/{cfg['experiment_name']}/"
    logger = get_logger(log_path)
    logger("**************** New Start ****************")
    logger(cfg)
    g = load_generator(cfg)
    print_banner("Generating samples")
    filename = generate_samples(cfg, g)
    logger(f"Samples generated: {filename}.npz")
    print_banner("Generating interpolations")
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
        generate_interpolations(cfg, g, fix_z, fix_y)
    logger("Interpolations generated")
    del g
    torch.cuda.empty_cache()
    # print_banner("Showing some generated samples")
    # show_samples(cfg["experiment_name"], filename)
    print_banner("Calculating IS")
    generated_samples = load_generated_samples(cfg["experiment_name"], filename)
    generated_samples_IS = get_inception_score(generated_samples)
    logger(f"Inception Score: {generated_samples_IS[0]:.6f}+{generated_samples_IS[1]:.6f}")
    print_banner("Calculating FID (training & fake samples)")
    real_samples, _ = load_f100_samples(mode='train')
    fid = get_fid(real_samples, generated_samples)
    logger(f"FID: {fid:.6f}")
    logger("**************** Done ****************")
Example #2
    # Method from a TF1-style evaluator class; assumes tensorflow as tf, numpy as np,
    # and an inception_score module are in scope.
    def evaluate_inception(self):
        incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
        logits, _ = load_inception_inference(
            self.sess, self.classes, incep_batch_size,
            self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
        pred_op = tf.nn.softmax(logits)

        z = tf.placeholder(tf.float32, [self.bs, self.model.z_dim], name='z')
        cond = tf.placeholder(tf.float32, [self.bs, self.model.embed_dim],
                              name='cond')
        eval_gen, _, _ = self.model.generator(z,
                                              cond,
                                              reuse=False,
                                              is_training=False)

        saver = tf.train.Saver(tf.global_variables('g_net'))
        could_load, _ = load(saver, self.sess, self.cfg.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            raise RuntimeError(
                'Could not load the checkpoints of the generator')

        print('Generating x...')

        size = self.cfg.EVAL.SIZE
        n_batches = size // self.bs

        w, h, c = (self.model.image_dims[0], self.model.image_dims[1],
                   self.model.image_dims[2])
        samples = np.zeros((n_batches * self.bs, w, h, c))
        for i in range(n_batches):
            print("\rGenerating batch %d/%d" % (i + 1, n_batches),
                  end="",
                  flush=True)

            sample_z = np.random.normal(0, 1, size=(self.bs, self.model.z_dim))
            _, _, embed, _, _ = self.dataset.test.next_batch(self.bs,
                                                             4,
                                                             embeddings=True)
            start = i * self.bs
            end = start + self.bs

            gen_batch = self.sess.run(eval_gen,
                                      feed_dict={
                                          z: sample_z,
                                          cond: embed
                                      })
            samples[start:end] = denormalize_images(gen_batch)

        print('\nComputing inception score...')
        mean, std = inception_score.get_inception_score(samples,
                                                        self.sess,
                                                        incep_batch_size,
                                                        10,
                                                        pred_op,
                                                        verbose=True)
        print('Inception Score | mean:', "%.2f" % mean, 'std:', "%.2f" % std)
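Note: denormalize_images is not defined in this excerpt. A plausible minimal version, assuming the generator produces tanh outputs in [-1, 1] (an assumption, not the project's code):

import numpy as np

def denormalize_images(images):
    # Map tanh outputs in [-1, 1] back to uint8 pixel values in [0, 255] (assumed convention)
    return ((images + 1.0) * 127.5).astype(np.uint8)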
Example #3
File: main.py  Project: chrisbyd/sharegan
def get_inception_score():
    datacfg = data_cfg.get_config(args.dataset)  # only used by the commented-out reshape below
    generated_pics = np.load(
        os.path.join(SAMPLE, args.model_name, args.dataset,
                     'X_gan_%s.npy' % args.model_name))
    #generated_pics=np.reshape(generated_pics,[-1,datacfg.dataset.image_size,datacfg.dataset.image_size,datacfg.dataset.channels])
    mean, std = inception_score.get_inception_score(generated_pics)
    log.info('the inception score is %s, with standard deviation %s' %
             (mean, std))
Example #4
def main(_):
    """Evaluate model on Dataset for a number of steps."""
    with tf.Graph().as_default():
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            with tf.device("/gpu:%d" % FLAGS.gpu):
                logits, _ = load_inception_inference(sess, FLAGS.num_classes,
                                                     FLAGS.batch_size,
                                                     FLAGS.checkpoint_dir)
                pred_op = tf.nn.softmax(logits)

                images = load_inception_data(FLAGS.img_folder)
                mean, std = get_inception_score(images, sess, FLAGS.batch_size,
                                                FLAGS.splits, pred_op)
                print('mean:', "%.2f" % mean, 'std:', "%.2f" % std)
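Note: load_inception_data reads FLAGS.img_folder above but its body is not shown. A hypothetical stand-in that simply loads every image in the folder as an RGB array (the name and behavior are assumptions, not this project's implementation):

import os
import numpy as np
from PIL import Image

def load_inception_data_sketch(img_folder):
    # Hypothetical: collect each image under img_folder as an HWC uint8 array
    images = []
    for name in sorted(os.listdir(img_folder)):
        with Image.open(os.path.join(img_folder, name)) as img:
            images.append(np.asarray(img.convert('RGB')))
    return images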
Example #5
def calculate_baseline(dataset='Face100'):
    log_path = f"./logs/"
    logger = get_logger(log_path, dataset + ".log")
    logger("**************** New Start ****************")
    print("loading samples...")
    if dataset == 'C10':
        train_samples, _ = load_cifar_samples(mode="train")
        test_samples, _ = load_cifar_samples(mode="test")
    elif dataset == 'Face100':
        train_samples, _ = load_f100_samples(mode="train")
        test_samples, _ = load_f100_samples(mode="test")
    else:
        print("unsupported dataset")
        return
    print("loading done")
    baseline_IS = get_inception_score(train_samples)
    logger(f"Inception Score (training samples): {baseline_IS[0]:.6f}+{baseline_IS[1]:.6f}")
    baseline_FID = get_fid(train_samples, test_samples)
    logger(f"FID (between training samples & test samples): {baseline_FID:.6f}")
    logger("**************** Done ****************")