def test_nearest_neighbors_k_48(self):
        """Check KnnModel's 48-nearest-neighbour table against MATLAB output.

        Points close to the border of the toy problem are excluded from the
        comparison (their neighbour sets do not line up with the reference),
        by zeroing their rows in both tables before comparing each row as a
        sorted set of indices.
        """
        problem = ToyProblem()
        model = KnnModel(problem, k=48)

        # MATLAB indices are 1-based; transpose so rows line up with the
        # per-point rows of model.ind.
        nn_values = scipy.io.loadmat(
            "tests/matlab_variables/nearest_neighbors_48nn.mat")
        nn_values = nn_values['nearest_neighbors'].T - 1

        # Interior points: neither coordinate lies in the edge margins
        # (<= 4 or >= 47). Written as a negation so NaN coordinates are
        # classified exactly as the original np.where masks did.
        xs = problem.points[:, 0]
        ys = problem.points[:, 1]
        interior = ~((xs <= 4) | (xs >= 47) | (ys <= 4) | (ys >= 47))
        keep_rows = interior[:, None]

        # Zero out edge-point rows in both tables so they compare equal.
        nn_values = nn_values * keep_rows
        model.ind = model.ind * keep_rows

        # assert_array_equal also checks shapes, unlike np.all(a == b), which
        # could broadcast mismatched tables or collapse to a scalar.
        np.testing.assert_array_equal(np.sort(nn_values, axis=1),
                                      np.sort(model.ind, axis=1))
    def test_ENS_4nn_single_iteration(self):
        """Compare a single ENS scoring step against MATLAB reference values.

        The model's weights and 4-NN indices are overridden with the MATLAB
        exports for the jittered ToyProblem. After observing the first positive
        point, the UnlabelSelector candidates and the ENS utilities are checked
        against the saved MATLAB results.
        """
        budget = 99
        problem = ToyProblem(jitter=True)
        model = KnnModel(problem, k=4)

        weight_matrix_matlab = scipy.io.loadmat(
            "tests/matlab_variables/weights_4nn_jitter.mat")
        weight_matrix_matlab = weight_matrix_matlab['weights']
        nearest_neighbors_matlab = scipy.io.loadmat(
            "tests/matlab_variables/nearest_neighbors_4nn_jitter.mat")
        # MATLAB indices are 1-based.
        nearest_neighbors_matlab = nearest_neighbors_matlab[
            'nearest_neighbors'] - 1

        model.weight_matrix = weight_matrix_matlab
        model.ind = nearest_neighbors_matlab.T
        expected_scores = scipy.io.loadmat(
            "tests/matlab_variables/ens_utilities_4nn.mat")
        expected_scores = expected_scores['utilities']

        utility = ENS()
        selector = UnlabelSelector()

        policy = ArgMaxPolicy(problem, model, utility)
        np.random.seed(3)
        positive_indices = [
            i for i, x in enumerate(problem.labels_deterministic) if x > 0
        ]

        # Seed the observed data with the first positive point.
        firstObsIndex = positive_indices[0]
        currentData = Data()
        firstPointValue = problem.oracle_function(firstObsIndex)
        currentData.new_observation(firstObsIndex, firstPointValue)

        test_indices = selector.filter(currentData, problem.points, model,
                                       policy, problem, budget)

        expected_test_indices = scipy.io.loadmat(
            "tests/matlab_variables/ens_test_indices_4nn.mat")
        expected_test_indices = expected_test_indices['test_ind'] - 1
        expected_test_indices = np.sort(expected_test_indices, axis=0)

        # zip() stops at the shorter input, so check the candidate counts
        # first; otherwise a truncated candidate set would pass silently.
        assert len(test_indices) == len(expected_test_indices)
        for index, expected_index in zip(test_indices, expected_test_indices):
            assert index == expected_index

        scores = utility.get_scores(model, currentData, test_indices, budget,
                                    problem.points)
        for score, expected in zip(scores, expected_scores):
            assert score == pytest.approx(expected)
    def test_ENS_4nn_every_iteration_with_pruning(self):
        """Run the ENS selection loop and compare every iteration to MATLAB.

        On each iteration, candidates are sorted by descending predicted
        probability and ENS utilities are computed with those probabilities.
        The utilities are compared against per-budget reference values. If the
        chosen point differs from the MATLAB choice, a warning is emitted and
        the MATLAB choice is used, so later iterations stay comparable.

        NOTE(review): `expected_scores` and `expected_selected_indices` are not
        loaded in this method. Unless they are module-level names defined
        elsewhere in this file, the first iteration raises NameError.
        TODO: load the per-iteration MATLAB ENS utilities / selected indices.
        """
        budget = 99
        problem = ToyProblem(jitter=True)
        model = KnnModel(problem, k=4)

        weight_matrix_matlab = scipy.io.loadmat(
            "tests/matlab_variables/weights_4nn_jitter.mat")
        weight_matrix_matlab = weight_matrix_matlab['weights']
        nearest_neighbors_matlab = scipy.io.loadmat(
            "tests/matlab_variables/nearest_neighbors_4nn_jitter.mat")
        # MATLAB indices are 1-based.
        nearest_neighbors_matlab = nearest_neighbors_matlab[
            'nearest_neighbors'] - 1

        model.weight_matrix = weight_matrix_matlab
        model.ind = nearest_neighbors_matlab.T

        utility = ENS()
        selector = UnlabelSelector()

        policy = ENSPolicy(problem, model, utility)
        np.random.seed(3)
        positive_indices = [
            i for i, x in enumerate(problem.labels_deterministic) if x > 0
        ]

        # Seed the observed data with the first positive point.
        firstObsIndex = positive_indices[0]
        currentData = Data()
        firstPointValue = problem.oracle_function(firstObsIndex)
        currentData.new_observation(firstObsIndex, firstPointValue)

        while budget > 0:
            test_indices = selector.filter(currentData, problem.points, model,
                                           policy, problem, budget)

            budget_string = 'budget' + str(budget)

            # Order candidates by descending predicted probability; the same
            # ordering is passed on to get_scores.
            probabilities = policy.model.predict(currentData, test_indices)
            argsort_ind = (-probabilities).argsort(axis=0)
            probabilities = probabilities[argsort_ind[:, 0]]
            test_indices = test_indices[argsort_ind[:, 0]]

            scores = utility.get_scores(model, currentData, test_indices,
                                        budget, problem.points, probabilities)
            max_index = np.argmax(scores)

            this_iter_expected_scores = expected_scores[budget_string][0][0]
            for score, expected in zip(scores, this_iter_expected_scores):
                assert score == pytest.approx(expected, abs=1e-13)

            chosen_x_index = test_indices[max_index]

            # Unwrap the nested loadmat container once to a scalar and convert
            # from MATLAB's 1-based index, so the comparison and the
            # replacement below use the same value.
            this_expected_selected_index = expected_selected_indices[
                budget_string][0][0][0][0] - 1

            if chosen_x_index != this_expected_selected_index:
                warnings.warn(
                    UserWarning(
                        "chosen index doesnt match up, however expected scores may match. replaced chosen index"
                    ))
                chosen_x_index = this_expected_selected_index

            y = problem.oracle_function(chosen_x_index)
            currentData.new_observation(chosen_x_index, y)

            budget -= 1
    def test_one_step_4nn_every_iteration(self):
        """Run the one-step selection loop and compare every iteration to MATLAB.

        On each iteration, the UnlabelSelector candidates are checked against
        the MATLAB candidate set for that budget. One-step utilities are then
        computed on the MATLAB candidates and compared to the reference values.
        The best-scoring MATLAB candidate is observed next.
        """
        budget = 99
        problem = ToyProblem(jitter=True)
        model = KnnModel(problem, k=4)

        weight_matrix_matlab = scipy.io.loadmat(
            "tests/matlab_variables/weights_4nn_jitter.mat")
        weight_matrix_matlab = weight_matrix_matlab['weights']
        nearest_neighbors_matlab = scipy.io.loadmat(
            "tests/matlab_variables/nearest_neighbors_4nn_jitter.mat")
        # MATLAB indices are 1-based.
        nearest_neighbors_matlab = nearest_neighbors_matlab[
            'nearest_neighbors'] - 1

        model.weight_matrix = weight_matrix_matlab
        model.ind = nearest_neighbors_matlab.T
        expected_scores = scipy.io.loadmat(
            "tests/matlab_variables/one_step_utilities_every_iter_4nn.mat")
        expected_scores = expected_scores['test_policies_utilities']

        # NOTE(review): the test-indices file is read with the same key as the
        # utilities file ('test_policies_utilities'); confirm this matches the
        # variable name saved on the MATLAB side.
        expected_test_indices = scipy.io.loadmat(
            "tests/matlab_variables/one_step_test_indices_every_iter_4nn.mat")
        expected_test_indices = expected_test_indices[
            'test_policies_utilities']

        utility = OneStep()
        selector = UnlabelSelector()

        policy = ArgMaxPolicy(problem, model, utility)
        np.random.seed(3)
        positive_indices = [
            i for i, x in enumerate(problem.labels_deterministic) if x > 0
        ]

        # Seed the observed data with the first positive point.
        firstObsIndex = positive_indices[0]
        currentData = Data()
        firstPointValue = problem.oracle_function(firstObsIndex)
        currentData.new_observation(firstObsIndex, firstPointValue)

        while budget > 0:
            test_indices = selector.filter(currentData, problem.points, model,
                                           policy, problem, budget)

            # Reference fields are keyed by budget + 1 and hold 1-based
            # indices wrapped in a nested loadmat container.
            budget_string = 'budget' + str(budget + 1)
            this_iter_expected_test_indices = expected_test_indices[
                budget_string] - 1
            this_iter_expected_test_indices = this_iter_expected_test_indices[
                0][0].reshape(-1, )

            # zip() stops at the shorter input, so check the candidate counts
            # first; otherwise a truncated candidate set would pass silently.
            assert len(test_indices) == len(this_iter_expected_test_indices)
            for index, expected_index in zip(test_indices,
                                             this_iter_expected_test_indices):
                assert index == expected_index

            # Score the MATLAB candidates so utilities are compared on
            # identical inputs.
            scores = utility.get_scores(model, currentData,
                                        this_iter_expected_test_indices,
                                        budget, problem.points)
            max_index = np.argmax(scores)

            this_iter_expected_scores = expected_scores[budget_string][0][0]
            for score, expected in zip(scores, this_iter_expected_scores):
                assert score == pytest.approx(expected, abs=1e-13)

            chosen_x_index = this_iter_expected_test_indices[max_index]

            y = problem.oracle_function(chosen_x_index)
            currentData.new_observation(chosen_x_index, y)

            budget -= 1