for i in range(10):
    chosen_schedule_start = int(schedule_starts[i])
    for each_t in range(chosen_schedule_start, chosen_schedule_start + 20):
        optimizer.zero_grad()
        pred = ddt(x_test[each_t])
        loss = F.cross_entropy(pred.reshape(1, 2), y_test[each_t].long())
        acc = (pred.argmax(dim=-1) == y_test[each_t].item()).to(
            torch.float32).mean()
        test_losses.append(loss.item())
        test_accs.append(acc.item())
print('Loss: {}, Accuracy: {}'.format(np.mean(test_losses),
                                      np.mean(test_accs)))
print('Finished Training')

### REAL TEST
ddt = convert_to_crisp(ddt, None)

x_data_test, y_test, percent_of_zeros = create_simple_classification_dataset(
    50, True)

x_test = []
for each_ele in x_data_test:
    x_test.append(each_ele[2:])

x_test = torch.Tensor(x_test).reshape(-1, 1, 2)
y_test = torch.Tensor(y_test).reshape((-1, 1))
test_losses, test_accs = [], []
per_schedule_test_losses, per_schedule_test_accs = [], []
preds, actual = [[] for _ in range(50)], [[] for _ in range(50)]
for i in range(50):
    chosen_schedule_start = int(schedule_starts[i])
    def test_again_crisp(self, model, test_embeddings):
        """
        Evaluate the performance of a network trained with the alpha-divergence loss.
        Evaluation uses 20% of the data; results are stored in a text file.
        Note: this function is called after training has converged.
        :return:
        """
        # define a new optimizer that only updates the Bayesian embedding

        self.model = convert_to_crisp(model, None)
        num_schedules = 100
        # load in new data
        load_directory = '/home/ghost/PycharmProjects/bayesian_prolo/scheduling_env/datasets/' + str(
            num_schedules) + 'test_dist_early_hili_pairwise.pkl'
        sig = torch.nn.Sigmoid()
        data = pickle.load(open(load_directory, "rb"))
        X, Y, schedule_array = create_new_data(num_schedules, data)

        prediction_accuracy = [0, 0]
        percentage_accuracy_top1 = []
        percentage_accuracy_top3 = []
        embedding_optimizer = torch.optim.SGD([{'params': self.model.bayesian_embedding.parameters()}], lr=.01)
        criterion = torch.nn.BCELoss()

        embedding_list = test_embeddings

        for j in range(0, num_schedules):
            schedule_bounds = schedule_array[j]
            step = schedule_bounds[0]
            self.model.set_bayesian_embedding(embedding_list[j])

            while step < schedule_bounds[1]:
                probability_matrix = np.zeros((20, 20))
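                # probability_matrix[m][n]: preference probability of task m over task n
                # within this set of twenty; the diagonal stays zero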

                for m, counter in enumerate(range(step, step + 20)):
                    phi_i = X[counter]
                    phi_i_numpy = np.asarray(phi_i)

                    # for each set of twenty
                    for n, second_counter in enumerate(range(step, step + 20)):
                        # fill the entire matrix, leaving the diagonal at zero
                        if second_counter == counter:  # same as m == n
                            continue
                        phi_j = X[second_counter]
                        phi_j_numpy = np.asarray(phi_j)

                        feature_input = phi_i_numpy - phi_j_numpy

                        if torch.cuda.is_available():
                            feature_input = Variable(torch.Tensor(feature_input.reshape(1, 13)).cuda())

                        else:
                            feature_input = Variable(torch.Tensor(feature_input.reshape(1, 13)))

                        # push the pairwise feature difference through the crisp model
                        preference_prob = sig(self.model.forward(feature_input))
                        probability_matrix[m][n] = preference_prob[0].data.detach()[
                            0].item()  # TODO: you can do a check if only this line leads to the same thing as the line below
                        # probability_matrix[n][m] = preference_prob[0].data.detach()[1].item()

                # Set of twenty is completed
                column_vec = np.sum(probability_matrix, axis=1)
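                # row sums: total preference mass each task accumulates over the other nineteen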

                embedding_list[j] = torch.Tensor(self.model.get_bayesian_embedding().detach().cpu().numpy())  # very ugly

                # top 1
                # given all inputs and their likelihood of being scheduled, predict the output
                highest_val = max(column_vec)
                all_indexes_that_have_highest_val = [i for i, e in enumerate(list(column_vec)) if e == highest_val]
                if len(all_indexes_that_have_highest_val) > 1:
                    print('length of indexes greater than 1: ', all_indexes_that_have_highest_val)
                # top 1
                choice = np.random.choice(all_indexes_that_have_highest_val)
                # choice = np.argmax(probability_vector)

                # top 3
                _, top_three = torch.topk(torch.Tensor(column_vec), 3)

                # Then do training update loop
                truth = Y[step]

                # index top 1
                if choice == truth:
                    prediction_accuracy[0] += 1

                # index top 3
                if truth in top_three:
                    prediction_accuracy[1] += 1

                # move on to the next set of twenty
                step += 20

            # schedule finished
            print('Prediction Accuracy: top1: ', prediction_accuracy[0] / 20, ' top3: ', prediction_accuracy[1] / 20)

            print('schedule num:', j)
            percentage_accuracy_top1.append(prediction_accuracy[0] / 20)
            percentage_accuracy_top3.append(prediction_accuracy[1] / 20)
            embedding_list[j] = torch.Tensor(self.model.get_bayesian_embedding().detach().cpu().numpy())  # very ugly

            prediction_accuracy = [0, 0]
        print('mean top1 accuracy: ', np.mean(percentage_accuracy_top1))
def test(ddt):

    x_data_test, y_test, percent_of_zeros = create_simple_classification_dataset(
        50, get_percent_of_zeros=True)
    schedule_starts = np.linspace(0, int(50 * 20 - 20), num=50)
    x_test = []

    for each_ele in x_data_test:
        x_test.append(each_ele[2:])

    x_test = torch.Tensor(x_test).reshape(-1, 1, 2)
    y_test = torch.Tensor(y_test).reshape((-1, 1))

    test_losses, test_accs = [], []
    per_schedule_test_losses, per_schedule_test_accs = [], []
    preds, actual = [[] for _ in range(50)], [[] for _ in range(50)]
    test_distributions = [np.ones(2) * 1 / 2 for _ in range(50)]
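    # per-schedule belief over the two modes of the bimodal embedding, initialised uniform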
    total_acc = []
    for i in range(50):
        chosen_schedule_start = int(schedule_starts[i])
        schedule_num = int(chosen_schedule_start / 20)
        embedding_given_dis, count = get_embedding_given_dist(
            test_distributions[schedule_num])
        prod = [.5, .5]
        acc = 0
        ddt.set_bayesian_embedding(embedding_given_dis)

        for each_t in range(chosen_schedule_start, chosen_schedule_start + 20):
            # at each timestep you want to resample the embedding

            x_t = x_test[each_t]
            output = ddt.forward(x_t).reshape(1, 2)

            label = y_test[each_t]
            label = torch.Tensor([label]).reshape(1)
            label = label.long()
            print('output is ',
                  torch.argmax(output).item(), ' label is ', label.item())
            if torch.argmax(output).item() == label.item():
                acc += 1
            tally = output[0][int(label.item())].item()
            second_tally = output[0][int(not label.item())].item()
            prod[count] = tally * test_distributions[i][count]
            prod[int(not count)] *= second_tally * test_distributions[i][int(not count)]
            preds[i].append(torch.argmax(output).item())
            actual[i].append(label.item())
            normalization_factor = sum(prod)
            prod = [k / normalization_factor for k in prod]
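            # the normalised products become the schedule's updated belief over embedding modes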

            test_distributions[schedule_num][0] = prod[0]
            test_distributions[schedule_num][1] = prod[1]
            normalization_factor_for_dist = sum(test_distributions[schedule_num])
            test_distributions[schedule_num] /= normalization_factor_for_dist
            print('distribution at time ', each_t, ' is',
                  test_distributions[schedule_num])
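            # during the first five timesteps of a schedule, keep sampling an embedding from
            # the belief; afterwards commit to the most likely embedding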
            if each_t % 20 < 5:
                embedding_given_dis, count = get_embedding_given_dist(
                    test_distributions[schedule_num])
            else:
                embedding_given_dis = get_most_likely_embedding_given_dist(
                    test_distributions[schedule_num])
            ddt.set_bayesian_embedding(embedding_given_dis)

        per_schedule_test_accs.append(acc / 20)
    # print('Loss: {}, Accuracy: {}'.format(0, np.mean(per_schedule_test_accs)))
    print(test_distributions)
    print('per sched accuracy: ', np.mean(per_schedule_test_accs))
    sensitivity, specificity = compute_sensitivity(
        preds, actual), compute_specificity(preds, actual)

    test_losses, test_accs = [], []
    per_schedule_test_losses, per_schedule_test_accs = [], []
    preds, actual = [[] for _ in range(50)], [[] for _ in range(50)]
    total_acc = []
    for i in range(50):
        chosen_schedule_start = int(schedule_starts[i])
        schedule_num = int(chosen_schedule_start / 20)
        embedding_given_dis, count = get_embedding_given_dist(
            test_distributions[schedule_num])
        prod = [.5, .5]
        acc = 0
        ddt.set_bayesian_embedding(embedding_given_dis)

        for each_t in range(chosen_schedule_start, chosen_schedule_start + 20):
            # at each timestep you want to resample the embedding

            x_t = x_test[each_t]
            output = ddt.forward(x_t).reshape(1, 2)

            label = y_test[each_t]
            label = torch.Tensor([label]).reshape(1)
            label = label.long()
            # print('output is ', torch.argmax(output).item(), ' label is ', label.item())
            if torch.argmax(output).item() == label.item():
                acc += 1
            tally = output[0][int(label.item())].item()
            second_tally = output[0][int(not label.item())].item()
            prod[count] = tally * test_distributions[i][count]
            prod[int(not count)] *= second_tally * test_distributions[i][int(not count)]
            preds[i].append(torch.argmax(output).item())
            actual[i].append(label.item())

        per_schedule_test_accs.append(acc / 20)
    # print('Loss: {}, Accuracy: {}'.format(0, np.mean(per_schedule_test_accs)))
    print('per sched accuracy: ', np.mean(per_schedule_test_accs))
    fuzzy_sensitivity, fuzzy_specificity = compute_sensitivity(
        preds, actual), compute_specificity(preds, actual)
    fuzzy_accuracy = np.mean(per_schedule_test_accs)

    ddt = convert_to_crisp(ddt, None)

    test_losses, test_accs = [], []
    per_schedule_test_losses, per_schedule_test_accs = [], []
    preds, actual = [[] for _ in range(50)], [[] for _ in range(50)]

    total_acc = []
    for i in range(50):
        chosen_schedule_start = int(schedule_starts[i])
        schedule_num = int(chosen_schedule_start / 20)
        embedding_given_dis, count = get_embedding_given_dist(
            test_distributions[schedule_num])
        prod = [.5, .5]
        acc = 0
        ddt.set_bayesian_embedding(embedding_given_dis)

        for each_t in range(chosen_schedule_start, chosen_schedule_start + 20):
            # at each timestep you want to resample the embedding

            x_t = x_test[each_t]
            output = ddt.forward(x_t).reshape(1, 2)

            label = y_test[each_t]
            label = torch.Tensor([label]).reshape(1)
            label = label.long()
            # print('output is ', torch.argmax(output).item(), ' label is ', label.item())
            if torch.argmax(output).item() == label.item():
                acc += 1
            tally = output[0][int(label.item())].item()
            second_tally = output[0][int(not label.item())].item()
            prod[count] = tally * test_distributions[i][count]
            prod[int(not count)] *= second_tally * test_distributions[i][int(not count)]
            preds[i].append(torch.argmax(output).item())
            actual[i].append(label.item())

        per_schedule_test_accs.append(acc / 20)
    # print('Loss: {}, Accuracy: {}'.format(0, np.mean(per_schedule_test_accs)))
    print('per sched accuracy: ', np.mean(per_schedule_test_accs))
    crisp_sensitivity, crisp_specificity = compute_sensitivity(
        preds, actual), compute_specificity(preds, actual)
    crisp_accuracy = np.mean(per_schedule_test_accs)

    print('mean crisp sensitivity: ', crisp_sensitivity,
          ', mean crisp specificity: ', crisp_specificity)
    with open('heterogeneous_toy_env_results.txt', 'a') as results_file:
        results_file.write('crisp DDT w/ bimodal embedding: crisp mean: ' +
                           str(crisp_accuracy) + ', fuzzy mean: ' + str(fuzzy_accuracy) +
                           ', crisp sensitivity: ' + str(crisp_sensitivity) +
                           ', crisp specificity: ' + str(crisp_specificity) +
                           ', fuzzy sensitivity: ' + str(fuzzy_sensitivity) +
                           ', fuzzy specificity: ' + str(fuzzy_specificity) +
                           ', Distribution of Class: 0: ' + str(percent_of_zeros) +
                           ', 1: ' + str(1 - percent_of_zeros) + '\n')
    def train(self):
        """
        Trains the BDT.
        Randomly samples a schedule and a timestep within it, builds training data from the
        pairwise differences x_i - x_j, and trains on them.
        :return:
        """

        threshold = .1
        training_done = False
        loss_func = AlphaLoss()

        # deepening data
        deepen_data = {'samples': [], 'labels': [], 'embedding_indices': []}
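        # buffers for the pairwise samples/labels that the (currently commented-out)
        # tree-deepening step further below would consume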

        # variables to keep track of loss and number of tasks trained over
        running_loss_predict_tasks = 0
        num_iterations_predict_task = 0

        while not training_done:
            # sample a timestep before the cutoff for cross_validation
            rand_timestep_within_sched = np.random.randint(
                len(self.start_of_each_set_twenty))
            set_of_twenty = self.start_of_each_set_twenty[
                rand_timestep_within_sched]
            truth = self.Y[set_of_twenty]

            which_schedule = find_which_schedule_this_belongs_to(
                self.schedule_array, set_of_twenty)
            load_in_embedding(self.model, self.embedding_list, which_schedule)

            # find feature vector of true action taken
            phi_i_num = truth + set_of_twenty
            phi_i = self.X[phi_i_num]
            phi_i_numpy = np.asarray(phi_i)

            # iterate over pairwise comparisons
            for counter in range(set_of_twenty, set_of_twenty + 20):
                if counter == phi_i_num:  # skip the true action itself
                    continue
                else:
                    phi_j = self.X[counter]
                    phi_j_numpy = np.asarray(phi_j)
                    feature_input = phi_i_numpy - phi_j_numpy
                    deepen_data['samples'].append(np.array(feature_input))
                    # label = add_noise_pairwise(label, self.noise_percentage)
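                    # target P puts nearly all mass on index 0: phi_i (the true action)
                    # should be preferred over phi_j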
                    if torch.cuda.is_available():
                        feature_input = Variable(
                            torch.Tensor(feature_input.reshape(1, 13)).cuda())
                        P = Variable(
                            torch.Tensor([
                                1 - self.distribution_epsilon,
                                self.distribution_epsilon
                            ]).cuda())
                    else:
                        feature_input = Variable(
                            torch.Tensor(feature_input.reshape(1, 13)))
                        P = Variable(
                            torch.Tensor([
                                1 - self.distribution_epsilon,
                                self.distribution_epsilon
                            ]))

                    output = self.model(feature_input)
                    loss = loss_func.forward(P, output, self.alpha)

                    # NAN check (fix is in the bnn file)
                    if torch.isnan(loss):
                        print(self.alpha, ' :nan occurred at iteration ',
                              self.total_iterations)

                    # prepare optimizer, compute gradient, update params
                    self.opt.zero_grad()
                    if loss.item() < .001 or loss.item() > 50:
                        pass
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                       0.5)
                        self.opt.step()

                    running_loss_predict_tasks += loss.item()
                    num_iterations_predict_task += 1

                    deepen_data['labels'].extend([0])
                    deepen_data['embedding_indices'].extend([which_schedule])

            for counter in range(set_of_twenty, set_of_twenty + 20):
                if counter == phi_i_num:
                    continue
                else:
                    phi_j = self.X[counter]
                    phi_j_numpy = np.asarray(phi_j)
                    feature_input = phi_j_numpy - phi_i_numpy
                    deepen_data['samples'].append(np.array(feature_input))
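                    # reversed pair phi_j - phi_i, so target P puts nearly all mass on index 1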
                    if torch.cuda.is_available():
                        feature_input = Variable(
                            torch.Tensor(feature_input.reshape(1, 13)).cuda())
                        P = Variable(
                            torch.Tensor([
                                self.distribution_epsilon,
                                1 - self.distribution_epsilon
                            ]).cuda())
                    else:
                        feature_input = Variable(
                            torch.Tensor(feature_input.reshape(1, 13)))
                        P = Variable(
                            torch.Tensor([
                                self.distribution_epsilon,
                                1 - self.distribution_epsilon
                            ]))

                    output = self.model(feature_input)
                    loss = loss_func.forward(P, output, self.alpha)
                    # if num_iterations_predict_task % 5 == 0:
                    #     print('loss is :', loss.item())
                    # clip any very high gradients

                    # prepare optimizer, compute gradient, update params
                    self.opt.zero_grad()
                    if loss.item() < .001 or loss.item() > 50:
                        pass
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                       0.5)
                        self.opt.step()

                    running_loss_predict_tasks += loss.item()
                    num_iterations_predict_task += 1

                    deepen_data['labels'].extend([1])
                    deepen_data['embedding_indices'].extend([which_schedule])

            # add average loss to array
            # print(list(self.model.parameters()))

            self.embedding_list = store_embedding_back(self.model,
                                                       self.embedding_list,
                                                       which_schedule)
            self.total_loss_array.append(running_loss_predict_tasks /
                                         num_iterations_predict_task)
            num_iterations_predict_task = 0
            running_loss_predict_tasks = 0

            self.total_iterations += 1

            if self.total_iterations > 25 and self.total_iterations % 50 == 1:
                print('total iterations is', self.total_iterations)
                print('average loss over the last 40 iterations:',
                      np.mean(self.total_loss_array[-40:]))

            if self.total_iterations > 0 and self.total_iterations % self.when_to_save == self.when_to_save - 1:
                self.save_trained_nets('BDDT' + str(self.num_schedules))
                threshold -= .025

            # if self.total_iterations % 500 == 499:
            #     # self.model = deepen_with_embeddings(self.model, deepen_data, self.embedding_list, max_depth=self.max_depth, threshold=threshold)
            #     params = list(self.model.parameters())
            #     del params[0]
            #     self.opt = torch.optim.RMSprop([{'params': params}, {'params': self.model.bayesian_embedding, 'lr': .001}])
            # deepen_data = {
            #     'samples': [],
            #     'labels': [],
            #     'embedding_indices': []
            # }
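            # convergence: stop once the last-100 average loss exceeds the last-500 average
            # by less than covergence_epsilon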

            if self.total_iterations > 2500 and np.mean(
                    self.total_loss_array[-100:]) - np.mean(
                        self.total_loss_array[-500:]) < self.covergence_epsilon:
                training_done = True

                self.model = convert_to_crisp(self.model, None)