Example #1
    def testComputeQmapGridRelu(self):
        """Test checks the compute_qmap_grid function.

    (i) Checks sizes are appropriate and (ii) checks
    accuracy of the numerical values generated by the
    grid by comparing against the analytically known
    form for Relu (Cho and Saul, '09).
    """
        n_gauss, n_var, n_corr = 301, 33, 31
        kernel = nngp.NNGPKernel(nonlin_fn=tf.nn.relu,
                                 n_gauss=n_gauss,
                                 n_var=n_var,
                                 n_corr=n_corr)

        var_aa_grid = kernel.var_aa_grid
        corr_ab_grid = kernel.corr_ab_grid
        qaa_grid = kernel.qaa_grid
        qab_grid = kernel.qab_grid

        qaa_exact = 0.5 * var_aa_grid
        qab_exact = self.ExactQabArcCos(var_aa_grid, corr_ab_grid)

        with self.test_session() as sess:
            self.assertEqual(var_aa_grid.shape.as_list(), [n_var])
            self.assertEqual(corr_ab_grid.shape.as_list(), [n_corr])
            self.assertAllClose(sess.run(qaa_exact),
                                sess.run(qaa_grid),
                                rtol=1e-6)
            self.assertAllClose(sess.run(qab_exact),
                                sess.run(qab_grid),
                                rtol=1e-6,
                                atol=2e-2)
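
The helper self.ExactQabArcCos is not shown in these examples. A minimal sketch of what it could compute, assuming the degree-1 arc-cosine kernel of Cho and Saul ('09): for pre-activation variance q and correlation c, qab = (q / (2*pi)) * (sin(theta) + (pi - theta) * cos(theta)) with theta = arccos(c); at c = 1 this reduces to q / 2, matching qaa_exact above. The function name and output shape here are illustrative, not the repo's actual helper.

import numpy as np

def exact_qab_arccos(var_aa_grid, corr_ab_grid):
    # Analytic ReLU post-activation covariance (Cho & Saul '09, degree 1),
    # evaluated on the outer product of the variance and correlation grids.
    q = np.asarray(var_aa_grid)[:, None]                      # [n_var, 1]
    c = np.clip(np.asarray(corr_ab_grid), -1., 1.)[None, :]   # [1, n_corr]
    theta = np.arccos(c)
    return (q / (2. * np.pi)) * (np.sin(theta) +
                                 (np.pi - theta) * np.cos(theta))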
Example #2
    def testComputeQmapGridReluLogSpacing(self):
        n_gauss, n_var, n_corr = 301, 33, 31
        kernel = nngp.NNGPKernel(nonlin_fn=tf.nn.relu,
                                 n_gauss=n_gauss,
                                 n_var=n_var,
                                 n_corr=n_corr)

        var_aa_grid = kernel.var_aa_grid
        corr_ab_grid = kernel.corr_ab_grid
        qaa_grid = kernel.qaa_grid
        qab_grid = kernel.qab_grid

        qaa_exact = 0.5 * var_aa_grid
        qab_exact = self.ExactQabArcCos(var_aa_grid, corr_ab_grid)

        with self.test_session() as sess:
            self.assertEqual(var_aa_grid.shape.as_list(), [n_var])
            self.assertEqual(corr_ab_grid.shape.as_list(), [n_corr])
            self.assertAllClose(sess.run(qaa_exact),
                                sess.run(qaa_grid),
                                rtol=1e-6,
                                atol=2e-2)
            self.assertAllClose(sess.run(qab_exact),
                                sess.run(qab_grid),
                                rtol=1e-6,
                                atol=2e-2)
Example #3
    def testComputeQmapGridRelu(self):
        """Test checks the compute_qmap_grid function.

        (i) Checks sizes are appropriate and (ii) checks
        accuracy of the numerical values generated by the
        grid by comparing against the analytically known
        form for Relu (Cho and Saul, '09).
        """
        n_gauss, n_var, n_corr = 301, 33, 31
        kernel = nngp.NNGPKernel(nonlin_fn=lambda x: x * (x > 0),
                                 n_gauss=n_gauss,
                                 n_var=n_var,
                                 n_corr=n_corr)

        var_aa_grid = kernel.var_aa_grid
        corr_ab_grid = kernel.corr_ab_grid
        qaa_grid = kernel.qaa_grid
        qab_grid = kernel.qab_grid

        qaa_exact = 0.5 * var_aa_grid
        qab_exact = self.ExactQabArcCos(var_aa_grid, corr_ab_grid)

        self.assertEqual(list(var_aa_grid.shape), [n_var])
        self.assertEqual(list(corr_ab_grid.shape), [n_corr])
        self.assertTrue(np.allclose(qaa_exact, qaa_grid, rtol=1e-6))
        self.assertTrue(np.allclose(qab_exact, qab_grid, rtol=1e-6, atol=2e-2))
Example #4
    def testComputeQmapGridEvenNGauss(self):
        n_gauss, n_var, n_corr = 102, 33, 31

        with self.assertRaises(ValueError):
            nngp.NNGPKernel(nonlin_fn=tf.nn.relu,
                            n_gauss=n_gauss,
                            n_var=n_var,
                            n_corr=n_corr)
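
The odd-n_gauss requirement is presumably so the quadrature grid is symmetric with a node at zero. A hedged sketch of the kind of guard that would raise this ValueError; the actual check inside nngp.NNGPKernel may differ:

def _check_n_gauss(n_gauss):
    # Hypothetical validation: an odd node count keeps the Gaussian
    # quadrature grid symmetric about the origin.
    if n_gauss % 2 == 0:
        raise ValueError('n_gauss=%d must be odd.' % n_gauss)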
Example #5
    def testGetVarFixedPoint(self):
        n_gauss, n_var, n_corr = 101, 33, 31
        weight_var, bias_var = 1.9, 0.2
        # For ReLU, the variance map is q -> bias_var + weight_var * q / 2,
        # so its fixed point is bias_var / (1 - weight_var / 2).
        analytic_fixed_point = bias_var / (1. - weight_var / 2)

        kernel = nngp.NNGPKernel(nonlin_fn=tf.nn.relu,
                                 weight_var=weight_var,
                                 bias_var=bias_var,
                                 n_gauss=n_gauss,
                                 n_var=n_var,
                                 n_corr=n_corr)
        fixed_point, _ = kernel.get_var_fixed_point()

        self.assertAllClose(analytic_fixed_point, fixed_point[0], atol=1e-4)
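
As a cross-check on analytic_fixed_point, the ReLU variance map q -> bias_var + weight_var * q / 2 can be iterated directly (for z ~ N(0, q), E[relu(z)^2] = q / 2). A minimal numpy sketch, independent of nngp:

import numpy as np

def relu_var_fixed_point(weight_var=1.9, bias_var=0.2, n_iter=500, q0=1.0):
    # Iterate the layer-to-layer variance recursion for ReLU.
    q = q0
    for _ in range(n_iter):
        q = bias_var + weight_var * q / 2.
    return q

# The map is a contraction for weight_var < 2 and converges to
# bias_var / (1 - weight_var / 2) = 4.0 for these values.
assert np.isclose(relu_var_fixed_point(), 0.2 / (1. - 1.9 / 2.))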
Example #6
    def testComputeQmapGridReluLogSpacing(self):
        """Test checks the compute_qmap_grid function with log_spacing=True.
        """
        n_gauss, n_var, n_corr = 301, 33, 31
        kernel = nngp.NNGPKernel(nonlin_fn=lambda x: x * (x > 0),
                                 n_gauss=n_gauss,
                                 n_var=n_var,
                                 n_corr=n_corr)

        var_aa_grid = kernel.var_aa_grid
        corr_ab_grid = kernel.corr_ab_grid
        qaa_grid = kernel.qaa_grid
        qab_grid = kernel.qab_grid

        qaa_exact = 0.5 * var_aa_grid
        qab_exact = self.ExactQabArcCos(var_aa_grid, corr_ab_grid)

        self.assertEqual(list(var_aa_grid.shape), [n_var])
        self.assertEqual(list(corr_ab_grid.shape), [n_corr])
        self.assertTrue(np.allclose(qaa_exact, qaa_grid, rtol=1e-6, atol=2e-2))
        self.assertTrue(np.allclose(qab_exact, qab_grid, rtol=1e-6, atol=2e-2))
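
The qaa/qab grids these tests exercise are built by numerically integrating E[phi(u) phi(v)] over a correlated bivariate Gaussian. A hedged sketch of that computation for a single (variance, correlation) pair using Gauss-Hermite quadrature; the actual compute_qmap_grid vectorizes this over the whole grid and may differ in detail:

import numpy as np

def qab_single(nonlin_fn, q, c, n_gauss=301):
    # Nodes/weights for integrals against exp(-t^2); rescale for N(0, 1).
    t, w = np.polynomial.hermite.hermgauss(n_gauss)
    x = np.sqrt(2.) * t
    ww = np.outer(w, w) / np.pi          # normalized 2-D weights
    # (u, v) jointly Gaussian, each with variance q and correlation c.
    u = np.sqrt(q) * x[:, None]
    v = np.sqrt(q) * (c * x[:, None] + np.sqrt(1. - c**2) * x[None, :])
    return np.sum(ww * nonlin_fn(u) * nonlin_fn(v))

def relu(x):
    return x * (x > 0)

# Agrees with the Cho-Saul closed form to roughly the tolerance the
# tests above use for qab.
q, c = 2.0, 0.3
theta = np.arccos(c)
exact = (q / (2. * np.pi)) * (np.sin(theta) + (np.pi - theta) * np.cos(theta))
assert np.isclose(qab_single(relu, q, c), exact, atol=2e-2)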
Example #7
def run_nngp_eval(hparams, run_dir):
    """Runs experiments."""

    tf.gfile.MakeDirs(run_dir)
    # Write hparams to experiment directory.
    with tf.gfile.GFile(run_dir + '/hparams', mode='w') as f:
        f.write(hparams.to_proto().SerializeToString())

    tf.logging.info('Starting job.')
    tf.logging.info('Hyperparameters')
    tf.logging.info('---------------------')
    tf.logging.info(hparams)
    tf.logging.info('---------------------')
    tf.logging.info('Loading data')

    # Get the sets of images and labels for the training, validation, and
    # test splits of the dataset.
    if FLAGS.dataset == 'mnist':
        # Note: the validation split (valid_image, valid_label) is not
        # returned in this variant.
        (train_image, train_label, test_image,
         test_label) = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                               mean_subtraction=True,
                                               random_rotated_labels=False)
    elif FLAGS.dataset == 'cifar10':
        (train_image, train_label, test_image,
         test_label) = load_dataset.load_cifar10(num_train=FLAGS.num_train,
                                                 mean_subtraction=True)
    else:
        raise NotImplementedError

    tf.logging.info('Building Model')

    if hparams.nonlinearity == 'tanh':
        nonlin_fn = tf.tanh
    elif hparams.nonlinearity == 'relu':
        nonlin_fn = tf.nn.relu
    else:
        raise NotImplementedError

    with tf.Session() as sess:
        # Construct NNGP kernel
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            mu_2=hparams.mu_2,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm)

        # Construct Gaussian Process Regression model
        model = gpr.GaussianProcessRegression(train_image,
                                              train_label,
                                              kern=nngp_kernel)

        start_time = time.time()
        tf.logging.info('Training')

        # For a large number of training points, we do not evaluate on the
        # full set, to save evaluation time.
        if FLAGS.num_train <= 5000:
            acc_train, mse_train, var_train, norm_train, final_eps = do_eval(
                sess, model, train_image[:FLAGS.num_eval],
                train_label[:FLAGS.num_eval])
            tf.logging.info('Evaluation of training set (%d examples) took '
                            '%.3f secs' %
                            (min(FLAGS.num_train,
                                 FLAGS.num_eval), time.time() - start_time))
        else:
            acc_train, mse_train, var_train, norm_train, final_eps = do_eval(
                sess, model, train_image[:1000], train_label[:1000])
            tf.logging.info('Evaluation of training set (%d examples) took '
                            '%.3f secs' % (1000, time.time() - start_time))

        # start_time = time.time()
        # tf.logging.info('Validation')
        # acc_valid, mse_valid, var_valid, norm_valid, _ = do_eval(
        #     sess, model, valid_image[:FLAGS.num_eval],
        #     valid_label[:FLAGS.num_eval])
        # tf.logging.info('Evaluation of valid set (%d examples) took %.3f secs'%(
        #     FLAGS.num_eval, time.time() - start_time))

        start_time = time.time()
        tf.logging.info('Test')
        acc_test, mse_test, var_test, norm_test, _ = do_eval(
            sess,
            model,
            test_image[:FLAGS.num_eval],
            test_label[:FLAGS.num_eval],
            save_pred=False)
        tf.logging.info('Evaluation of test set (%d examples) took %.3f secs' %
                        (FLAGS.num_eval, time.time() - start_time))

    metrics = {
        'train_acc': float(acc_train),
        'train_mse': float(mse_train),
        'train_norm': float(norm_train),
        # 'valid_acc': float(acc_valid),
        # 'valid_mse': float(mse_valid),
        # 'valid_norm': float(norm_valid),
        'test_acc': float(acc_test),
        'test_mse': float(mse_test),
        'test_norm': float(norm_test),
        'stability_eps': float(final_eps),
    }

    record_results = [
        # FLAGS.num_train, hparams.nonlinearity, hparams.weight_var,
        # hparams.bias_var, hparams.mu_2, hparams.depth, acc_train, acc_valid, acc_test,
        # mse_train, mse_valid, mse_test, final_eps
        FLAGS.num_train,
        hparams.nonlinearity,
        hparams.weight_var,
        hparams.bias_var,
        hparams.mu_2,
        hparams.depth,
        acc_train,
        acc_test,
        var_test,
        norm_train
    ]
    if nngp_kernel.use_fixed_point_norm:
        metrics['var_fixed_point'] = float(nngp_kernel.var_fixed_point_np[0])
        record_results.append(nngp_kernel.var_fixed_point_np[0])

    # Store data
    result_file = os.path.join(run_dir,
                               'results' + str(hparams.depth) + '.csv')
    with tf.gfile.Open(result_file, 'a') as f:
        filewriter = csv.writer(f)
        filewriter.writerow(record_results)
    with tf.Session() as sess:
        varss = np.array(
            [x[0] for x in sess.run(nngp_kernel.layer_qaa_dict).values()])
        save_string = str(hparams.depth) + '_' + str(
            hparams.weight_var) + '_' + str(hparams.mu_2)
        np.save('results/vars/' + save_string, varss)

    return metrics
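
do_eval is not defined in these examples, and its return arity varies between variants (this one unpacks five values, including a predictive variance). A hedged sketch of the general shape of such a helper, assuming a model.predict(x, sess) method that returns the GP posterior mean and the stability epsilon used for the solve; all names are illustrative:

import numpy as np

def do_eval(sess, model, x_data, y_data, save_pred=False):
    # Illustrative: score GP posterior means with accuracy / MSE / norm.
    gp_pred, stability_eps = model.predict(x_data, sess)
    accuracy = np.mean(np.argmax(gp_pred, axis=1) == np.argmax(y_data, axis=1))
    mse = np.mean((gp_pred - y_data)**2)
    pred_norm = np.mean(np.linalg.norm(gp_pred, axis=1))
    if save_pred:
        np.save('gp_prediction.npy', gp_pred)  # hypothetical path
    return accuracy, mse, pred_norm, stability_eps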
Example #8
def run_nngp_eval(hparams, run_dir):
    """Runs experiments."""

    # Write hparams to experiment directory.
    tf.gfile.MakeDirs(run_dir)
    with tf.gfile.GFile(run_dir + "/hparams", mode="w") as f:
        f.write(hparams.to_proto().SerializeToString())

    tf.logging.info("Starting job.")
    tf.logging.info("Hyperparameters")
    tf.logging.info("---------------------")
    tf.logging.info(hparams)
    tf.logging.info("---------------------")
    tf.logging.info("Loading data")

    # Get the sets of images and labels for training, validation, and test on dataset.
    if FLAGS.dataset == "mnist":
        (
            train_image,
            train_label,
            valid_image,
            valid_label,
            test_image,
            test_label,
        ) = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                    mean_subtraction=True,
                                    random_roated_labels=False)
    else:
        raise NotImplementedError

    tf.logging.info("Building Model")

    if hparams.nonlinearity == "tanh":
        nonlin_fn = tf.tanh
    elif hparams.nonlinearity == "relu":
        nonlin_fn = tf.nn.relu
    else:
        raise NotImplementedError

    with tf.Session() as sess:
        # Construct NNGP kernel
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm,
        )

        # Construct Gaussian Process Regression model
        model = gpr.GaussianProcessRegression(train_image,
                                              train_label,
                                              kern=nngp_kernel)

        # 1. Training
        start_time = time.time()
        tf.logging.info("Training")
        # For a large number of training points, we do not evaluate on the
        # full set, to save evaluation time.
        if FLAGS.num_train <= 5000:
            acc_train, mse_train, norm_train, final_eps = do_eval(
                sess,
                model,
                train_image[:FLAGS.num_eval],
                train_label[:FLAGS.num_eval],
            )
            tf.logging.info(
                "Evaluation of training set (%d examples) took %.3f secs" %
                (min(FLAGS.num_train,
                     FLAGS.num_eval), time.time() - start_time))
        else:
            acc_train, mse_train, norm_train, final_eps = do_eval(
                sess, model, train_image[:1000], train_label[:1000])
            tf.logging.info(
                "Evaluation of training set (%d examples) took %.3f secs" %
                (1000, time.time() - start_time))

        # 2. Validation
        start_time = time.time()
        tf.logging.info("Validation")
        acc_valid, mse_valid, norm_valid, _ = do_eval(
            sess, model, valid_image[:FLAGS.num_eval],
            valid_label[:FLAGS.num_eval])

        tf.logging.info(
            "Evaluation of valid set (%d examples) took %.3f secs" %
            (FLAGS.num_eval, time.time() - start_time))

        # 3. Test
        start_time = time.time()
        tf.logging.info("Test")
        acc_test, mse_test, norm_test, _ = do_eval(
            sess,
            model,
            test_image[:FLAGS.num_eval],
            test_label[:FLAGS.num_eval],
            save_pred=False,
        )

        tf.logging.info("Evaluation of test set (%d examples) took %.3f secs" %
                        (FLAGS.num_eval, time.time() - start_time))

    metrics = {
        "train_acc": float(acc_train),
        "train_mse": float(mse_train),
        "train_norm": float(norm_train),
        "valid_acc": float(acc_valid),
        "valid_mse": float(mse_valid),
        "valid_norm": float(norm_valid),
        "test_acc": float(acc_test),
        "test_mse": float(mse_test),
        "test_norm": float(norm_test),
        "stability_eps": float(final_eps),
    }

    record_results = [
        FLAGS.num_train,
        hparams.nonlinearity,
        hparams.weight_var,
        hparams.bias_var,
        hparams.depth,
        acc_train,
        acc_valid,
        acc_test,
        mse_train,
        mse_valid,
        mse_test,
        final_eps,
    ]
    if nngp_kernel.use_fixed_point_norm:
        metrics["var_fixed_point"] = float(nngp_kernel.var_fixed_point_np[0])
        record_results.append(nngp_kernel.var_fixed_point_np[0])

    # Store data
    result_file = os.path.join(run_dir, "results.csv")
    with tf.gfile.Open(result_file, "a") as f:
        filewriter = csv.writer(f)
        filewriter.writerow(record_results)

    return metrics
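
For context, run_nngp_eval in these TF1-era scripts is typically driven by a small main that packs the flags into a tf.contrib.training.HParams (whose to_proto() the function serializes above). A hedged sketch; the exact flag set and the experiment_dir flag name are assumptions:

def main(argv):
    del argv  # unused
    hparams = tf.contrib.training.HParams(
        nonlinearity=FLAGS.nonlinearity,
        weight_var=FLAGS.weight_var,
        bias_var=FLAGS.bias_var,
        depth=FLAGS.depth)
    run_nngp_eval(hparams, FLAGS.experiment_dir)

if __name__ == '__main__':
    tf.app.run(main)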
Example #9
def run_nngp_eval(hparams, run_dir):
    """Runs experiments."""

    tf.gfile.MakeDirs(run_dir)
    # Write hparams to experiment directory.
    with tf.gfile.GFile(run_dir + '/hparams', mode='w') as f:
        f.write(hparams.to_proto().SerializeToString())

    tf.logging.info('Starting job.')
    tf.logging.info('Hyperparameters')
    tf.logging.info('---------------------')
    tf.logging.info(hparams)
    tf.logging.info('---------------------')
    tf.logging.info('Loading data')

    # Get the sets of images and labels for the training, validation, and
    # test splits of the dataset.
    if FLAGS.dataset == 'mnist':
        (train_image, train_label, valid_image, valid_label, test_image,
         test_label) = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                               mean_subtraction=True,
                                               random_roated_labels=False)
    elif FLAGS.dataset == 'qi':
        (train_image, train_label, test_image, test_label, y_mu, y_std,
         train_label_normal) = load_data()
    else:
        raise NotImplementedError
    tf.logging.info('train X shape: %s' % (str(train_image.shape)))
    tf.logging.info('train y shape: %s' % (str(train_label.shape)))
    tf.logging.info('test X shape: %s' % (str(test_image.shape)))
    tf.logging.info('test y shape: %s' % (str(test_label.shape)))

    #return train_image, train_label, valid_image, valid_label, test_image,test_label, mu_y,std_y

    tf.logging.info('Building Model')

    if hparams.nonlinearity == 'tanh':
        nonlin_fn = tf.tanh
    elif hparams.nonlinearity == 'relu':
        nonlin_fn = tf.nn.relu
    else:
        raise NotImplementedError

    with tf.Session() as sess:
        # Construct NNGP kernel
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm)

        # Construct Gaussian Process Regression model
        model = gpr.GaussianProcessRegression(train_image,
                                              train_label_normal,
                                              kern=nngp_kernel)

        start_time = time.time()
        tf.logging.info('Training')
        rmse_train = do_eval_qi(sess, model, train_image, train_label,
                                y_mu, y_std)
        tf.logging.info('Evaluation of training set (%d examples) took '
                        '%.3f secs' %
                        (train_image.shape[0], time.time() - start_time))

        start_time = time.time()
        tf.logging.info('Test')
        rmse_test = do_eval_qi(sess,
                               model,
                               test_image,
                               test_label,
                               y_mu,
                               y_std,
                               save_pred=False)
        tf.logging.info('Evaluation of test set (%d examples) took %.3f secs' %
                        (test_image.shape[0], time.time() - start_time))

        ## For large number of training points, we do not evaluate on full set to
        ## save on training evaluation time.
        #if FLAGS.num_train <= 5000:
        #  acc_train, mse_train, norm_train, final_eps = do_eval(
        #      sess, model, train_image[:FLAGS.num_eval],
        #      train_label[:FLAGS.num_eval])
        #  tf.logging.info('Evaluation of training set (%d examples) took '
        #                  '%.3f secs'%(
        #                      min(FLAGS.num_train, FLAGS.num_eval),
        #                      time.time() - start_time))
        #else:
        #  acc_train, mse_train, norm_train, final_eps = do_eval(
        #      sess, model, train_image[:1000], train_label[:1000])
        #  tf.logging.info('Evaluation of training set (%d examples) took '
        #                  '%.3f secs'%(1000, time.time() - start_time))

        #start_time = time.time()
        #tf.logging.info('Validation')
        #acc_valid, mse_valid, norm_valid, _ = do_eval(
        #    sess, model, valid_image[:FLAGS.num_eval],
        #    valid_label[:FLAGS.num_eval])
        #tf.logging.info('Evaluation of valid set (%d examples) took %.3f secs'%(
        #    FLAGS.num_eval, time.time() - start_time))

        #start_time = time.time()
        #tf.logging.info('Test')
        #acc_test, mse_test, norm_test, _ = do_eval(
        #    sess,
        #    model,
        #    test_image[:FLAGS.num_eval],
        #    test_label[:FLAGS.num_eval],
        #    save_pred=False)
        #tf.logging.info('Evaluation of test set (%d examples) took %.3f secs'%(
        #    FLAGS.num_eval, time.time() - start_time))

    metrics = {
        'train_rmse': float(rmse_train),
        'test_rmse': float(rmse_test),
    }

    record_results = [
        [
            "FLAGS.train_test_split", "hparams.nonlinearity",
            "hparams.weight_var", "hparams.bias_var", "hparams.depth",
            "rmse_train", "rmse_test"
        ],
        [
            FLAGS.train_test_split, hparams.nonlinearity, hparams.weight_var,
            hparams.bias_var, hparams.depth, rmse_train, rmse_test
        ],
    ]
    if nngp_kernel.use_fixed_point_norm:
        metrics['var_fixed_point'] = float(nngp_kernel.var_fixed_point_np[0])
        record_results[0].append("nngp_kernel.var_fixed_point_np")
        record_results[1].append(nngp_kernel.var_fixed_point_np[0])

    # Store data
    result_file = os.path.join(run_dir, 'results.csv')
    with tf.gfile.Open(result_file, 'a') as f:
        filewriter = csv.writer(f)
        # record_results is a list of rows (header + data); writerow would
        # dump both lists into a single CSV row, so use writerows.
        filewriter.writerows(record_results)

    return metrics
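
do_eval_qi is likewise not shown; given the y_mu / y_std arguments, it presumably undoes the label normalization before scoring, since the model is fit on train_label_normal. A hedged numpy sketch with the same assumed model.predict signature as before:

import numpy as np

def do_eval_qi(sess, model, x_data, y_data, y_mu, y_std, save_pred=False):
    # Illustrative regression eval: denormalize predictions, report RMSE.
    gp_pred, _ = model.predict(x_data, sess)
    pred = gp_pred * y_std + y_mu
    rmse = np.sqrt(np.mean((pred - y_data)**2))
    if save_pred:
        np.save('gp_prediction_qi.npy', pred)  # hypothetical path
    return rmse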
Example #10
def run_nngp_eval(args):
    """Runs experiments."""

    run_dir = args.experiment_dir
    os.makedirs(run_dir, exist_ok=True)
    hparams = {
        'nonlinearity': args.nonlinearity,
        'weight_var': args.weight_var,
        'bias_var': args.bias_var,
        'depth': args.depth
    }
    # Write hparams to experiment directory.
    with open(run_dir + '/hparams.txt', mode='w') as f:
        f.write(json.dumps(hparams))

    logging.info('Starting job.')
    logging.info('Hyperparameters')
    logging.info('---------------------')
    logging.info(hparams)
    logging.info('---------------------')
    logging.info('Loading data')

    # Get the sets of images and labels for the training, validation, and
    # test splits of the dataset.
    if args.dataset == 'mnist':
        (train_image, train_label, valid_image, valid_label, test_image,
         test_label) = load_dataset.load_mnist(args,
                                               num_train=args.num_train,
                                               mean_subtraction=True,
                                               random_roated_labels=False)
    else:
        raise NotImplementedError

    logging.info('Building Model')

    if hparams['nonlinearity'] == 'tanh':
        nonlin_fn = np.tanh
    elif hparams['nonlinearity'] == 'relu':

        def relu(x):
            return x * (x > 0)

        nonlin_fn = relu
    else:
        raise NotImplementedError

    # Construct NNGP kernel
    nngp_kernel = nngp.NNGPKernel(depth=hparams['depth'],
                                  weight_var=hparams['weight_var'],
                                  bias_var=hparams['bias_var'],
                                  nonlin_fn=nonlin_fn,
                                  grid_path=args.grid_path,
                                  n_gauss=args.n_gauss,
                                  n_var=args.n_var,
                                  n_corr=args.n_corr,
                                  max_gauss=args.max_gauss,
                                  max_var=args.max_var,
                                  use_precomputed_grid=True)

    # Construct Gaussian Process Regression model
    model = gpr.GaussianProcessRegression(train_image,
                                          train_label,
                                          kern=nngp_kernel)

    start_time = time.time()
    logging.info('Training')

    # For a large number of training points, we do not evaluate on the full
    # set, to save evaluation time.
    train_size = args.num_eval if args.num_train <= 5000 else 1000
    acc_train, mse_train, norm_train, final_eps = do_eval(
        args, model, train_image[:train_size], train_label[:train_size])
    logging.info('Evaluation of training set ({0} examples) took '
                 '{1:.3f} secs'.format(min(args.num_train, train_size),
                                       time.time() - start_time))

    start_time = time.time()
    logging.info('Validation')
    acc_valid, mse_valid, norm_valid, _ = do_eval(args, model,
                                                  valid_image[:args.num_eval],
                                                  valid_label[:args.num_eval])
    logging.info(
        'Evaluation of valid set ({0} examples) took {1:.3f} secs'.format(
            args.num_eval,
            time.time() - start_time))

    start_time = time.time()
    logging.info('Test')
    acc_test, mse_test, norm_test, _ = do_eval(args, model,
                                               test_image[:args.num_eval],
                                               test_label[:args.num_eval])
    logging.info(
        'Evaluation of test set ({0} examples) took {1:.3f} secs'.format(
            args.num_eval,
            time.time() - start_time))

    metrics = {
        'train_acc': float(acc_train),
        'train_mse': float(mse_train),
        'train_norm': float(norm_train),
        'valid_acc': float(acc_valid),
        'valid_mse': float(mse_valid),
        'valid_norm': float(norm_valid),
        'test_acc': float(acc_test),
        'test_mse': float(mse_test),
        'test_norm': float(norm_test),
        'stability_eps': float(final_eps),
    }

    record_results = [
        args.num_train, hparams['nonlinearity'], hparams['weight_var'],
        hparams['bias_var'], hparams['depth'], acc_train, acc_valid, acc_test,
        mse_train, mse_valid, mse_test, final_eps
    ]

    # Store data
    result_file = os.path.join(run_dir, 'results.csv')
    with open(result_file, 'a') as f:
        # Append one JSON-encoded row per run so repeated runs stay separable.
        f.write(json.dumps(record_results) + '\n')

    return metrics
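
This numpy variant takes a single args namespace rather than (hparams, run_dir). A hedged sketch of the argparse wiring it implies, listing only the attributes referenced above; the defaults are illustrative:

import argparse

def parse_args():
    p = argparse.ArgumentParser(description='NNGP eval (numpy variant)')
    p.add_argument('--experiment_dir', default='/tmp/nngp')
    p.add_argument('--dataset', default='mnist')
    p.add_argument('--nonlinearity', default='relu')
    p.add_argument('--weight_var', type=float, default=1.3)
    p.add_argument('--bias_var', type=float, default=0.2)
    p.add_argument('--depth', type=int, default=2)
    p.add_argument('--num_train', type=int, default=1000)
    p.add_argument('--num_eval', type=int, default=1000)
    p.add_argument('--grid_path', default='./grid_data')
    p.add_argument('--n_gauss', type=int, default=501)
    p.add_argument('--n_var', type=int, default=501)
    p.add_argument('--n_corr', type=int, default=500)
    p.add_argument('--max_gauss', type=int, default=10)
    p.add_argument('--max_var', type=int, default=100)
    return p.parse_args()

if __name__ == '__main__':
    run_nngp_eval(parse_args())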
Example #11
def run_nngp_eval(hparams, run_dir):
    """Runs experiments."""

    tf.gfile.MakeDirs(run_dir)
    # Write hparams to experiment directory.
    with tf.gfile.GFile(run_dir + '/hparams', mode='w') as f:
        f.write(hparams.to_proto().SerializeToString())

    tf.logging.info('Starting job.')
    tf.logging.info('Hyperparameters')
    tf.logging.info('---------------------')
    tf.logging.info(hparams)
    tf.logging.info('---------------------')
    tf.logging.info('Loading data')

    # Get the sets of images and labels for the training, validation, and
    # test splits of the dataset.
    if FLAGS.dataset == 'mnist':
        (train_image, train_label, valid_image, valid_label, test_image,
         test_label) = load_dataset.load_mnist(num_train=FLAGS.num_train,
                                               mean_subtraction=True,
                                               random_roated_labels=False)

    elif FLAGS.dataset == 'cifar':
        (train_image, train_label, valid_image, valid_label, test_image,
         test_label) = load_dataset.load_cifar10(num_train=FLAGS.num_train,
                                                 mean_subtraction=True)

    elif FLAGS.dataset == 'stl10':
        (train_image, train_label, valid_image, valid_label, test_image,
         test_label) = load_dataset.load_stl10(num_train=FLAGS.num_train,
                                               mean_subtraction=True)

    else:
        raise NotImplementedError

    tf.logging.info('Building Model')

    if hparams.nonlinearity == 'tanh':
        nonlin_fn = tf.tanh
    elif hparams.nonlinearity == 'relu':
        nonlin_fn = tf.nn.relu
    else:
        raise NotImplementedError

    session_conf = tf.ConfigProto(intra_op_parallelism_threads=12,
                                  inter_op_parallelism_threads=1)

    with tf.Session(config=session_conf) as sess:
        # Construct NNGP kernel
        nngp_kernel = nngp.NNGPKernel(
            depth=hparams.depth,
            weight_var=hparams.weight_var,
            bias_var=hparams.bias_var,
            nonlin_fn=nonlin_fn,
            grid_path=FLAGS.grid_path,
            n_gauss=FLAGS.n_gauss,
            n_var=FLAGS.n_var,
            n_corr=FLAGS.n_corr,
            max_gauss=FLAGS.max_gauss,
            max_var=FLAGS.max_var,
            use_fixed_point_norm=FLAGS.use_fixed_point_norm)

        # Construct Gaussian Process Regression model
        model = gpr.GaussianProcessRegression(train_image,
                                              train_label,
                                              kern=nngp_kernel)

        start_time = time.time()
        tf.logging.info('Training')

        # For a large number of training points, we do not evaluate on the
        # full set, to save evaluation time.
        if FLAGS.num_train <= 5000:
            acc_train, mse_train, norm_train, final_eps = do_eval(
                sess, model, train_image[:FLAGS.num_eval],
                train_label[:FLAGS.num_eval])
            tf.logging.info('Evaluation of training set (%d examples) took '
                            '%.3f secs' %
                            (min(FLAGS.num_train,
                                 FLAGS.num_eval), time.time() - start_time))
        else:
            acc_train, mse_train, norm_train, final_eps = do_eval(
                sess, model, train_image[:1000], train_label[:1000])
            tf.logging.info('Evaluation of training set (%d examples) took '
                            '%.3f secs' % (1000, time.time() - start_time))

        vfile = "validation_{0}_{1}_{2}_{3}_{4}_{5}.npy".format(
            FLAGS.dataset, FLAGS.num_train, FLAGS.num_eval, hparams.depth,
            hparams.weight_var, hparams.bias_var)

        start_time = time.time()
        tf.logging.info('Validation')
        acc_valid, mse_valid, norm_valid, _ = do_eval(
            sess,
            model,
            valid_image[:FLAGS.num_eval],
            valid_label[:FLAGS.num_eval],
            save_pred=True,
            fname=vfile)

        tf.logging.info(
            'Evaluation of valid set (%d examples) took %.3f secs' %
            (FLAGS.num_eval, time.time() - start_time))

        tfile = "test_{0}_{1}_{2}_{3}_{4}_{5}.npy".format(
            FLAGS.dataset, FLAGS.num_train, FLAGS.num_eval, hparams.depth,
            hparams.weight_var, hparams.bias_var)

        start_time = time.time()
        tf.logging.info('Test')
        acc_test, mse_test, norm_test, _ = do_eval(sess,
                                                   model,
                                                   test_image[:FLAGS.num_eval],
                                                   test_label[:FLAGS.num_eval],
                                                   save_pred=True,
                                                   fname=tfile)

        tf.logging.info('Evaluation of test set (%d examples) took %.3f secs' %
                        (FLAGS.num_eval, time.time() - start_time))

    metrics = {
        'train_acc': float(acc_train),
        'train_mse': float(mse_train),
        'train_norm': float(norm_train),
        'valid_acc': float(acc_valid),
        'valid_mse': float(mse_valid),
        'valid_norm': float(norm_valid),
        'test_acc': float(acc_test),
        'test_mse': float(mse_test),
        'test_norm': float(norm_test),
        'stability_eps': float(final_eps),
    }

    record_results = [
        FLAGS.num_train, hparams.nonlinearity, hparams.weight_var,
        hparams.bias_var, hparams.depth, acc_train, acc_valid, acc_test,
        mse_train, mse_valid, mse_test, final_eps
    ]
    if nngp_kernel.use_fixed_point_norm:
        metrics['var_fixed_point'] = float(nngp_kernel.var_fixed_point_np[0])
        record_results.append(nngp_kernel.var_fixed_point_np[0])

    # Store data
    rfile = "results_{0}_{1}_{2}_{3}_{4}_{5}.csv".format(FLAGS.dataset, \
            FLAGS.num_train, \
            FLAGS.num_eval, \
            hparams.depth, \
            hparams.weight_var, \
            hparams.bias_var)

    result_file = os.path.join(run_dir, rfile)
    with tf.gfile.Open(result_file, 'a') as f:
        filewriter = csv.writer(f)
        filewriter.writerow(record_results)

    return metrics