def run_nngp_eval(hparams, run_dir):
    """Runs one NNGP evaluation experiment and records its results.

    Serializes ``hparams`` into ``run_dir/hparams``, loads the dataset named by
    ``FLAGS.dataset`` ('mnist' or 'cifar10'), builds an NNGP kernel and a
    Gaussian-process regression model, evaluates it on the training and test
    splits, appends one CSV row to ``run_dir/results<depth>.csv`` and saves the
    first element of every entry of ``nngp_kernel.layer_qaa_dict`` to
    ``results/vars/<depth>_<weight_var>_<mu_2>.npy``.

    Args:
      hparams: hyperparameter object exposing ``nonlinearity``, ``weight_var``,
        ``bias_var``, ``mu_2``, ``depth`` and ``to_proto()``.
      run_dir: directory receiving the hparams and results files; created if
        missing.

    Returns:
      Dict of float metrics: train/test accuracy, MSE and norm,
      ``stability_eps`` and, when the fixed-point norm is enabled,
      ``var_fixed_point``.

    Raises:
      NotImplementedError: for an unsupported dataset or nonlinearity.
    """
    tf.gfile.MakeDirs(run_dir)
    # Write hparams to experiment directory.
    with tf.gfile.GFile(run_dir + '/hparams', mode='w') as f:
        f.write(hparams.to_proto().SerializeToString())

    tf.logging.info('Starting job.')
    tf.logging.info('Hyperparameters')
    tf.logging.info('---------------------')
    tf.logging.info(hparams)
    tf.logging.info('---------------------')
    tf.logging.info('Loading data')

    # Get the sets of images and labels for training and test on the dataset.
    # The loaders are unpacked into four arrays here (no validation split).
    if FLAGS.dataset == 'mnist':
        (train_image, train_label, test_image,
         test_label) = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                               mean_subtraction=True,
                                               random_rotated_labels=False)
    elif FLAGS.dataset == 'cifar10':
        (train_image, train_label, test_image,
         test_label) = load_dataset.load_cifar10(num_train=FLAGS.num_train,
                                                 mean_subtraction=True)
    else:
        raise NotImplementedError

    tf.logging.info('Building Model')

    if hparams.nonlinearity == 'tanh':
        nonlin_fn = tf.tanh
    elif hparams.nonlinearity == 'relu':
        nonlin_fn = tf.nn.relu
    else:
        raise NotImplementedError

    with tf.Session() as sess:
        # Construct NNGP kernel
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            mu_2=hparams.mu_2,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm)

        # Construct Gaussian Process Regression model
        model = gpr.GaussianProcessRegression(train_image,
                                              train_label,
                                              kern=nngp_kernel)

        start_time = time.time()
        tf.logging.info('Training')

        # For a large number of training points we do not evaluate on the full
        # set, to save on training evaluation time.
        if FLAGS.num_train <= 5000:
            num_train_eval = FLAGS.num_eval
            num_logged = min(FLAGS.num_train, FLAGS.num_eval)
        else:
            num_train_eval = num_logged = 1000
        # The third value returned by do_eval (train variance) is not used.
        acc_train, mse_train, _, norm_train, final_eps = do_eval(
            sess, model, train_image[:num_train_eval],
            train_label[:num_train_eval])
        tf.logging.info('Evaluation of training set (%d examples) took '
                        '%.3f secs' % (num_logged, time.time() - start_time))

        start_time = time.time()
        tf.logging.info('Test')
        acc_test, mse_test, var_test, norm_test, _ = do_eval(
            sess,
            model,
            test_image[:FLAGS.num_eval],
            test_label[:FLAGS.num_eval],
            save_pred=False)
        tf.logging.info('Evaluation of test set (%d examples) took %.3f secs' %
                        (FLAGS.num_eval, time.time() - start_time))

    metrics = {
        'train_acc': float(acc_train),
        'train_mse': float(mse_train),
        'train_norm': float(norm_train),
        'test_acc': float(acc_test),
        'test_mse': float(mse_test),
        'test_norm': float(norm_test),
        'stability_eps': float(final_eps),
    }

    # CSV row layout: num_train, nonlinearity, weight_var, bias_var, mu_2,
    # depth, acc_train, acc_test, var_test, norm_train[, var_fixed_point].
    record_results = [
        FLAGS.num_train,
        hparams.nonlinearity,
        hparams.weight_var,
        hparams.bias_var,
        hparams.mu_2,
        hparams.depth,
        acc_train,
        acc_test,
        var_test,
        norm_train,
    ]
    if nngp_kernel.use_fixed_point_norm:
        metrics['var_fixed_point'] = float(nngp_kernel.var_fixed_point_np[0])
        record_results.append(nngp_kernel.var_fixed_point_np[0])

    # Store data
    result_file = os.path.join(run_dir,
                               'results' + str(hparams.depth) + '.csv')
    with tf.gfile.Open(result_file, 'a') as f:
        filewriter = csv.writer(f)
        filewriter.writerow(record_results)

    # Dump the first element of every entry of layer_qaa_dict. np.save does not
    # create parent directories, so make sure the target directory exists.
    # NOTE(review): this runs in a fresh session; if these tensors depend on
    # variables they would need initializing here -- confirm against nngp.
    vars_dir = 'results/vars'
    tf.gfile.MakeDirs(vars_dir)
    with tf.Session() as sess:
        varss = np.array(
            [x[0] for x in sess.run(nngp_kernel.layer_qaa_dict).values()])
    save_string = (str(hparams.depth) + '_' + str(hparams.weight_var) + '_' +
                   str(hparams.mu_2))
    np.save(os.path.join(vars_dir, save_string), varss)

    return metrics
示例#2
0
def run_nngp_eval(hparams, run_dir):
    """Evaluates an NNGP-kernel Gaussian-process model on MNIST.

    Creates ``run_dir``, stores the serialized hyperparameters in
    ``run_dir/hparams``, builds the NNGP kernel and GP regression model, and
    evaluates the training, validation and test splits. Appends one CSV row to
    ``run_dir/results.csv``.

    Args:
      hparams: hyperparameter object exposing ``nonlinearity``, ``weight_var``,
        ``bias_var``, ``depth`` and ``to_proto()``.
      run_dir: output directory for the hparams and results files.

    Returns:
      Dict of float train/valid/test metrics plus ``stability_eps`` (and
      ``var_fixed_point`` when the fixed-point norm is enabled).

    Raises:
      NotImplementedError: if ``FLAGS.dataset`` is not "mnist" or the
        nonlinearity is neither "tanh" nor "relu".
    """
    # Persist the hyperparameters next to the results.
    tf.gfile.MakeDirs(run_dir)
    with tf.gfile.GFile(run_dir + "/hparams", mode="w") as hparams_file:
        hparams_file.write(hparams.to_proto().SerializeToString())

    rule = "---------------------"
    for message in ("Starting job.", "Hyperparameters", rule, hparams, rule,
                    "Loading data"):
        tf.logging.info(message)

    # Only MNIST is supported by this variant.
    if FLAGS.dataset != "mnist":
        raise NotImplementedError
    (train_image, train_label, valid_image, valid_label, test_image,
     test_label) = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                           mean_subtraction=True,
                                           random_roated_labels=False)

    tf.logging.info("Building Model")

    activations = {"tanh": tf.tanh, "relu": tf.nn.relu}
    if hparams.nonlinearity not in activations:
        raise NotImplementedError
    nonlin_fn = activations[hparams.nonlinearity]

    with tf.Session() as sess:
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm,
        )
        model = gpr.GaussianProcessRegression(train_image, train_label,
                                              kern=nngp_kernel)

        # Training split: cap at 1000 points when the training set is large
        # to keep evaluation time down.
        tic = time.time()
        tf.logging.info("Training")
        if FLAGS.num_train > 5000:
            n_slice = n_shown = 1000
        else:
            n_slice = FLAGS.num_eval
            n_shown = min(FLAGS.num_train, FLAGS.num_eval)
        acc_train, mse_train, norm_train, final_eps = do_eval(
            sess, model, train_image[:n_slice], train_label[:n_slice])
        tf.logging.info(
            "Evaluation of training set (%d examples) took %.3f secs" %
            (n_shown, time.time() - tic))

        # Validation split.
        tic = time.time()
        tf.logging.info("Validation")
        acc_valid, mse_valid, norm_valid, _ = do_eval(
            sess, model, valid_image[:FLAGS.num_eval],
            valid_label[:FLAGS.num_eval])
        tf.logging.info(
            "Evaluation of valid set (%d examples) took %.3f secs" %
            (FLAGS.num_eval, time.time() - tic))

        # Test split.
        tic = time.time()
        tf.logging.info("Test")
        acc_test, mse_test, norm_test, _ = do_eval(
            sess, model, test_image[:FLAGS.num_eval],
            test_label[:FLAGS.num_eval], save_pred=False)
        tf.logging.info(
            "Evaluation of test set (%d examples) took %.3f secs" %
            (FLAGS.num_eval, time.time() - tic))

    metrics = {
        "train_acc": float(acc_train),
        "train_mse": float(mse_train),
        "train_norm": float(norm_train),
        "valid_acc": float(acc_valid),
        "valid_mse": float(mse_valid),
        "valid_norm": float(norm_valid),
        "test_acc": float(acc_test),
        "test_mse": float(mse_test),
        "test_norm": float(norm_test),
        "stability_eps": float(final_eps),
    }

    config_cols = [FLAGS.num_train, hparams.nonlinearity, hparams.weight_var,
                   hparams.bias_var, hparams.depth]
    score_cols = [acc_train, acc_valid, acc_test, mse_train, mse_valid,
                  mse_test, final_eps]
    record_results = config_cols + score_cols
    if nngp_kernel.use_fixed_point_norm:
        fixed_point = nngp_kernel.var_fixed_point_np[0]
        metrics["var_fixed_point"] = float(fixed_point)
        record_results.append(fixed_point)

    with tf.gfile.Open(os.path.join(run_dir, "results.csv"), "a") as out:
        csv.writer(out).writerow(record_results)

    return metrics
示例#3
0
def run_nngp_eval(args):
    """Runs one NNGP evaluation experiment configured by parsed CLI ``args``.

    Creates ``args.experiment_dir``, writes the hyperparameters as JSON to
    ``hparams.txt`` there, loads MNIST, builds the NNGP kernel and GP
    regression model, and evaluates the training, validation and test splits.
    Appends one JSON-encoded list, terminated by a newline, to ``results.csv``
    in the experiment directory.

    Args:
      args: namespace providing ``experiment_dir``, ``nonlinearity``,
        ``weight_var``, ``bias_var``, ``depth``, ``dataset``, ``num_train``,
        ``num_eval`` and the kernel grid settings (``grid_path``, ``n_gauss``,
        ``n_var``, ``n_corr``, ``max_gauss``, ``max_var``).

    Returns:
      Dict of float train/valid/test metrics plus ``stability_eps``.

    Raises:
      NotImplementedError: for an unsupported dataset or nonlinearity.
    """
    run_dir = args.experiment_dir
    os.makedirs(run_dir, exist_ok=True)
    hparams = {
        'nonlinearity': args.nonlinearity,
        'weight_var': args.weight_var,
        'bias_var': args.bias_var,
        'depth': args.depth
    }
    # Write hparams to experiment directory.
    with open(run_dir + '/hparams.txt', mode='w') as f:
        f.write(json.dumps(hparams))

    logging.info('Starting job.')
    logging.info('Hyperparameters')
    logging.info('---------------------')
    logging.info(hparams)
    logging.info('---------------------')
    logging.info('Loading data')

    # Get the sets of images and labels for training, validation, and
    # test on the dataset.
    if args.dataset == 'mnist':
        (train_image, train_label, valid_image, valid_label, test_image,
         test_label) = load_dataset.load_mnist(args,
                                               num_train=args.num_train,
                                               mean_subtraction=True,
                                               random_roated_labels=False)
    else:
        raise NotImplementedError

    logging.info('Building Model')

    if hparams['nonlinearity'] == 'tanh':
        nonlin_fn = np.tanh
    elif hparams['nonlinearity'] == 'relu':

        def relu(x):
            # Multiplying by the boolean mask zeroes out negative entries.
            return x * (x > 0)

        nonlin_fn = relu
    else:
        raise NotImplementedError

    # Construct NNGP kernel
    nngp_kernel = nngp.NNGPKernel(depth=hparams['depth'],
                                  weight_var=hparams['weight_var'],
                                  bias_var=hparams['bias_var'],
                                  nonlin_fn=nonlin_fn,
                                  grid_path=args.grid_path,
                                  n_gauss=args.n_gauss,
                                  n_var=args.n_var,
                                  n_corr=args.n_corr,
                                  max_gauss=args.max_gauss,
                                  max_var=args.max_var,
                                  use_precomputed_grid=True)

    # Construct Gaussian Process Regression model
    model = gpr.GaussianProcessRegression(train_image,
                                          train_label,
                                          kern=nngp_kernel)

    start_time = time.time()
    logging.info('Training')

    # For a large number of training points we do not evaluate on the full
    # set, to save on training evaluation time.
    train_size = args.num_eval if args.num_train <= 5000 else 1000
    acc_train, mse_train, norm_train, final_eps = do_eval(
        args, model, train_image[:train_size], train_label[:train_size])
    # Report the number of examples actually evaluated (at most num_train, and
    # only 1000 when the training set is large).
    logging.info('Evaluation of training set ({0} examples) took '
                 '{1:.3f} secs'.format(min(args.num_train, train_size),
                                       time.time() - start_time))

    start_time = time.time()
    logging.info('Validation')
    acc_valid, mse_valid, norm_valid, _ = do_eval(args, model,
                                                  valid_image[:args.num_eval],
                                                  valid_label[:args.num_eval])
    logging.info(
        'Evaluation of valid set ({0} examples) took {1:.3f} secs'.format(
            args.num_eval,
            time.time() - start_time))

    start_time = time.time()
    logging.info('Test')
    acc_test, mse_test, norm_test, _ = do_eval(args, model,
                                               test_image[:args.num_eval],
                                               test_label[:args.num_eval])
    logging.info(
        'Evaluation of test set ({0} examples) took {1:.3f} secs'.format(
            args.num_eval,
            time.time() - start_time))

    metrics = {
        'train_acc': float(acc_train),
        'train_mse': float(mse_train),
        'train_norm': float(norm_train),
        'valid_acc': float(acc_valid),
        'valid_mse': float(mse_valid),
        'valid_norm': float(norm_valid),
        'test_acc': float(acc_test),
        'test_mse': float(mse_test),
        'test_norm': float(norm_test),
        'stability_eps': float(final_eps),
    }

    # Metric values are converted with float() (as in `metrics`) so that
    # numpy scalar types do not trip up json.dumps.
    record_results = [
        args.num_train, hparams['nonlinearity'], hparams['weight_var'],
        hparams['bias_var'], hparams['depth'],
        float(acc_train), float(acc_valid), float(acc_test),
        float(mse_train), float(mse_valid), float(mse_test),
        float(final_eps)
    ]

    # Store data: one JSON list per line (despite the .csv name), so repeated
    # runs appending to the same file stay separable.
    result_file = os.path.join(run_dir, 'results.csv')
    with open(result_file, 'a') as f:
        f.write(json.dumps(record_results) + '\n')

    return metrics
示例#4
0
def run_nngp_eval(hparams, run_dir):
    """Runs an NNGP regression experiment on the 'qi' dataset.

    Serializes ``hparams`` into ``run_dir/hparams``, loads the data via
    ``load_data()``, fits a Gaussian-process regression model with an NNGP
    kernel on the normalized training targets, and reports train/test RMSE via
    ``do_eval_qi``. Appends results to ``run_dir/results.csv``. A header row is
    written only when the file is new, followed by one data row per run.

    Args:
      hparams: hyperparameter object exposing ``nonlinearity``, ``weight_var``,
        ``bias_var``, ``depth`` and ``to_proto()``.
      run_dir: output directory for the hparams and results files.

    Returns:
      Dict with float ``train_rmse`` and ``test_rmse`` (and
      ``var_fixed_point`` when the fixed-point norm is enabled).

    Raises:
      NotImplementedError: for any dataset other than 'qi', or an unsupported
        nonlinearity.
    """
    tf.gfile.MakeDirs(run_dir)
    # Write hparams to experiment directory.
    with tf.gfile.GFile(run_dir + '/hparams', mode='w') as f:
        f.write(hparams.to_proto().SerializeToString())

    tf.logging.info('Starting job.')
    tf.logging.info('Hyperparameters')
    tf.logging.info('---------------------')
    tf.logging.info(hparams)
    tf.logging.info('---------------------')
    tf.logging.info('Loading data')

    if FLAGS.dataset == 'qi':
        (train_image, train_label, test_image, test_label, y_mu, y_std,
         train_label_normal) = load_data()
    elif FLAGS.dataset == 'mnist':
        # The regression pipeline below needs y_mu, y_std and
        # train_label_normal, which the MNIST loader does not provide. This
        # branch previously loaded MNIST and then crashed with a NameError
        # while building the model, so fail fast with a clear message instead.
        raise NotImplementedError(
            "dataset 'mnist' is not supported by this regression evaluation: "
            'it provides no y_mu / y_std / train_label_normal')
    else:
        raise NotImplementedError
    tf.logging.info('train X shape: %s' % (str(train_image.shape)))
    tf.logging.info('train y shape: %s' % (str(train_label.shape)))
    tf.logging.info('test X shape: %s' % (str(test_image.shape)))
    tf.logging.info('test y shape: %s' % (str(test_label.shape)))

    tf.logging.info('Building Model')

    if hparams.nonlinearity == 'tanh':
        nonlin_fn = tf.tanh
    elif hparams.nonlinearity == 'relu':
        nonlin_fn = tf.nn.relu
    else:
        raise NotImplementedError

    with tf.Session() as sess:
        # Construct NNGP kernel
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm)

        # The GP is fit on the normalized targets; do_eval_qi receives the
        # raw labels together with y_mu / y_std.
        model = gpr.GaussianProcessRegression(train_image,
                                              train_label_normal,
                                              kern=nngp_kernel)

        start_time = time.time()
        tf.logging.info('Training')
        rmse_train = do_eval_qi(sess, model, train_image, train_label, y_mu,
                                y_std)
        tf.logging.info('Evaluation of training set (%d examples) took '
                        '%.3f secs' %
                        ((train_image.shape[0]), time.time() - start_time))

        start_time = time.time()
        tf.logging.info('Test')
        rmse_test = do_eval_qi(sess,
                               model,
                               test_image,
                               test_label,
                               y_mu,
                               y_std,
                               save_pred=False)
        tf.logging.info('Evaluation of test set (%d examples) took %.3f secs' %
                        (test_image.shape[0], time.time() - start_time))

    metrics = {
        'train_rmse': float(rmse_train),
        'test_rmse': float(rmse_test),
    }

    header = [
        "FLAGS.train_test_split", "hparams.nonlinearity",
        "hparams.weight_var", "hparams.bias_var", "hparams.depth",
        "rmse_train", "rmse_test"
    ]
    row = [
        FLAGS.train_test_split, hparams.nonlinearity, hparams.weight_var,
        hparams.bias_var, hparams.depth, rmse_train, rmse_test
    ]
    if nngp_kernel.use_fixed_point_norm:
        metrics['var_fixed_point'] = float(nngp_kernel.var_fixed_point_np[0])
        header.append("nngp_kernel.var_fixed_point_np")
        row.append(nngp_kernel.var_fixed_point_np[0])

    # Store data. The file is opened in append mode, so only emit the header
    # when the file does not exist yet; header and data are separate rows
    # (the previous writerow([header, row]) packed both lists into one row).
    result_file = os.path.join(run_dir, 'results.csv')
    write_header = not tf.gfile.Exists(result_file)
    with tf.gfile.Open(result_file, 'a') as f:
        filewriter = csv.writer(f)
        if write_header:
            filewriter.writerow(header)
        filewriter.writerow(row)

    return metrics
示例#5
0
def run_nngp_eval(hparams, run_dir):
    """Evaluates an NNGP Gaussian-process model on MNIST, CIFAR or STL-10.

    Creates ``run_dir`` and writes the serialized hyperparameters to
    ``run_dir/hparams``. Evaluates the training, validation and test splits.
    The validation and test predictions are passed to ``do_eval`` with
    ``save_pred=True`` and file names tagged by the run configuration. Appends
    one CSV row to ``run_dir/results_<tag>.csv``.

    Args:
      hparams: hyperparameter object exposing ``nonlinearity``, ``weight_var``,
        ``bias_var``, ``depth`` and ``to_proto()``.
      run_dir: output directory for the hparams and results files.

    Returns:
      Dict of float train/valid/test metrics plus ``stability_eps`` (and
      ``var_fixed_point`` when the fixed-point norm is enabled).

    Raises:
      NotImplementedError: for an unknown dataset or nonlinearity.
    """
    tf.gfile.MakeDirs(run_dir)
    with tf.gfile.GFile(run_dir + '/hparams', mode='w') as hp_file:
        hp_file.write(hparams.to_proto().SerializeToString())

    rule = '---------------------'
    for message in ('Starting job.', 'Hyperparameters', rule, hparams, rule,
                    'Loading data'):
        tf.logging.info(message)

    # Every loader returns train / valid / test images and labels.
    if FLAGS.dataset == 'mnist':
        splits = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                         mean_subtraction=True,
                                         random_roated_labels=False)
    elif FLAGS.dataset == 'cifar':
        splits = load_dataset.load_cifar10(num_train=FLAGS.num_train,
                                           mean_subtraction=True)
    elif FLAGS.dataset == 'stl10':
        splits = load_dataset.load_stl10(num_train=FLAGS.num_train,
                                         mean_subtraction=True)
    else:
        raise NotImplementedError
    (train_image, train_label, valid_image, valid_label, test_image,
     test_label) = splits

    tf.logging.info('Building Model')

    activations = {'tanh': tf.tanh, 'relu': tf.nn.relu}
    if hparams.nonlinearity not in activations:
        raise NotImplementedError
    nonlin_fn = activations[hparams.nonlinearity]

    # Tag shared by the prediction dumps and the results file name.
    run_tag = '{0}_{1}_{2}_{3}_{4}_{5}'.format(
        FLAGS.dataset, FLAGS.num_train, FLAGS.num_eval, hparams.depth,
        hparams.weight_var, hparams.bias_var)

    session_conf = tf.ConfigProto(intra_op_parallelism_threads=12,
                                  inter_op_parallelism_threads=1)

    with tf.Session(config=session_conf) as sess:
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm)
        model = gpr.GaussianProcessRegression(train_image, train_label,
                                              kern=nngp_kernel)

        # Training split: cap at 1000 points for large training sets to keep
        # evaluation time down.
        tic = time.time()
        tf.logging.info('Training')
        if FLAGS.num_train > 5000:
            n_slice = n_shown = 1000
        else:
            n_slice = FLAGS.num_eval
            n_shown = min(FLAGS.num_train, FLAGS.num_eval)
        acc_train, mse_train, norm_train, final_eps = do_eval(
            sess, model, train_image[:n_slice], train_label[:n_slice])
        tf.logging.info(
            'Evaluation of training set (%d examples) took %.3f secs' %
            (n_shown, time.time() - tic))

        # Validation split, predictions saved to disk.
        tic = time.time()
        tf.logging.info('Validation')
        acc_valid, mse_valid, norm_valid, _ = do_eval(
            sess, model, valid_image[:FLAGS.num_eval],
            valid_label[:FLAGS.num_eval], save_pred=True,
            fname='validation_' + run_tag + '.npy')
        tf.logging.info(
            'Evaluation of valid set (%d examples) took %.3f secs' %
            (FLAGS.num_eval, time.time() - tic))

        # Test split, predictions saved to disk.
        tic = time.time()
        tf.logging.info('Test')
        acc_test, mse_test, norm_test, _ = do_eval(
            sess, model, test_image[:FLAGS.num_eval],
            test_label[:FLAGS.num_eval], save_pred=True,
            fname='test_' + run_tag + '.npy')
        tf.logging.info(
            'Evaluation of test set (%d examples) took %.3f secs' %
            (FLAGS.num_eval, time.time() - tic))

    metrics = {
        'train_acc': float(acc_train),
        'train_mse': float(mse_train),
        'train_norm': float(norm_train),
        'valid_acc': float(acc_valid),
        'valid_mse': float(mse_valid),
        'valid_norm': float(norm_valid),
        'test_acc': float(acc_test),
        'test_mse': float(mse_test),
        'test_norm': float(norm_test),
        'stability_eps': float(final_eps),
    }

    config_cols = [FLAGS.num_train, hparams.nonlinearity, hparams.weight_var,
                   hparams.bias_var, hparams.depth]
    score_cols = [acc_train, acc_valid, acc_test, mse_train, mse_valid,
                  mse_test, final_eps]
    record_results = config_cols + score_cols
    if nngp_kernel.use_fixed_point_norm:
        fixed_point = nngp_kernel.var_fixed_point_np[0]
        metrics['var_fixed_point'] = float(fixed_point)
        record_results.append(fixed_point)

    result_file = os.path.join(run_dir, 'results_' + run_tag + '.csv')
    with tf.gfile.Open(result_file, 'a') as out:
        csv.writer(out).writerow(record_results)

    return metrics