Example #1
# --- Experiment setup (script-level configuration) ---
# Select the polyhedron category and fetch its list of formulas.
category = formula.POLYHEDRON
formulas = generate_formulas()
formula_list = formulas.get(category)

# By naming convention, the first formula is the parent and the second its
# child variant; only the parent is used below.
parent_formula = formula_list[0]
child_formula = formula_list[1]

# Sampling range for generated data points.
lower_bound = -1000
upper_bound = 1000

# Training hyper-parameters.
learning_rate = 0.01
training_epochs = 1000
parts_num = 5

data_point_number = 200
util.reset_random_seed()

# Generate train/test splits from the parent formula.
# NOTE(review): the two trailing 50s presumably are per-split point counts —
# confirm against generate_partitioned_data's signature.
train_set_x, train_set_y, test_set_x, test_set_y = formula_data_point_generation.generate_partitioned_data(
    parent_formula, category,
    lower_bound,
    upper_bound,
    50, 50)

# Labeling oracle built from the parent formula.
label_tester = lt.FormulaLabelTester(parent_formula)
point_number_limit = 200
util.reset_random_seed()
# Where the trained model is saved by the training loop.
model_folder = "models/test-method/test-branch"
model_file = "test-branch"

util.reset_random_seed()
# NOTE(review): this statement was truncated in the source — its argument
# list was cut off mid-line and fused with the method definition that
# follows. The call is closed here with the remaining configured
# hyper-parameter; TODO confirm the full argument list against
# benchmark.generate_accuracy's actual signature.
train_acc, test_acc = benchmark.generate_accuracy(train_set_x, train_set_y, test_set_x, test_set_y, learning_rate,
                                                  training_epochs)
    def train(self):
        """Iteratively train the network while growing the training set.

        Alternates between (1) training ``ns.NNStructure`` on the current
        training set and (2) appending newly generated points (generalization
        validation plus mid-point selection) until ``self.train_set_x``
        reaches ``self.point_number_limit`` or 10 outer iterations have run.
        Saves the model each iteration and notifies via ``communication``
        when finished.

        Returns:
            Tuple ``(train_acc_list, test_acc_list, data_point_number_list,
            appended_point_list)`` — one entry per outer iteration.
        """
        print("=========MID_POINT===========")
        # Per-iteration history returned to the caller.
        train_acc_list = []
        test_acc_list = []
        data_point_number_list = []
        appended_point_list = []
        # Accumulators for the points appended within one outer iteration
        # (cleared and reused each round).
        total_appended_x = []
        total_appended_y = []

        # Fresh TF1 graph + fixed seed so every call trains from scratch
        # deterministically.
        tf.reset_default_graph()
        util.reset_random_seed()

        with tf.Session() as sess:
            # Input width is taken from the first training point.
            net = ns.NNStructure(len(self.train_set_x[0]), self.learning_rate)
            sess.run(net.init)

            count = 1
            # Outer loop: train, evaluate, then augment the training set.
            # The hard cap of 10 iterations guards against never reaching
            # the point-number limit.
            while len(self.train_set_x) < self.point_number_limit and count < 10:
                print("*******", count, "th loop:")
                label_0, label_1 = util.data_partition(self.train_set_x, self.train_set_y)
                print("label 0 length", len(label_0), "label 1 length", len(label_1))

                # Inner loop: full-batch gradient steps on the current set.
                for iteration in range(self.training_epochs):
                    train_x_ = self.train_set_x
                    train_y_ = self.train_set_y
                    sess.run([net.train_op, net.loss_op],
                             feed_dict={net.X: train_x_, net.Y: train_y_})

                # Training accuracy on the (current) training set.
                train_y = sess.run(net.probability, feed_dict={net.X: self.train_set_x})
                train_acc = util.calculate_accuracy(train_y, self.train_set_y, False)

                print("train_acc", train_acc)
                train_acc_list.append(train_acc)

                # Test accuracy is only tracked when a test set was supplied.
                if self.test_set_x is not None:
                    test_y = sess.run(net.probability, feed_dict={net.X: self.test_set_x})
                    test_acc = util.calculate_accuracy(test_y, self.test_set_y, False)
                    test_acc_list.append(test_acc)
                    print("test_acc", test_acc)

                data_point_number_list.append(len(self.train_set_x))

                # Plot the current decision boundary (probability threshold 0.5).
                predicted = tf.cast(net.probability > 0.5, dtype=tf.float32)
                util.plot_decision_boundary(lambda x: sess.run(predicted, feed_dict={net.X: x}),
                                            self.train_set_x, self.train_set_y, self.lower_bound, self.upper_bound,
                                            count)

                # Stop augmenting once the limit is exceeded (strictly greater;
                # the while condition above uses strictly less).
                if len(self.train_set_x) > self.point_number_limit:
                    break

                std_dev = util.calculate_std_dev(self.train_set_x)

                # Cluster the training data; border_point_number splits the
                # generalization budget across 2 * cluster_number_limit groups.
                cluster_number_limit = 5
                border_point_number = (int)(self.generalization_valid_limit / (2 * cluster_number_limit)) + 1
                centers, centers_label, clusters, border_points_groups = self.cluster_training_data(border_point_number,
                                                                                                    cluster_number_limit)

                appending_dict = {}
                print("start generalization validation")

                # Reset the reused per-iteration accumulators.
                total_appended_x.clear()
                total_appended_y.clear()

                # Phase 1: points from generalization validation.
                appended_x, appended_y = self.append_generalization_validation_points(sess, net,
                                                                                      std_dev,
                                                                                      centers,
                                                                                      centers_label,
                                                                                      clusters,
                                                                                      border_points_groups,
                                                                                      )
                total_appended_x += appended_x
                total_appended_y += appended_y
                appending_dict["generalization_validation"] = appended_x

                # Phase 2: mid-points between centers of differently-labelled
                # cluster pairs.
                print("start midpoint selection")
                diff_label_pair_list = self.select_diff_label_point_pair(centers, centers_label, clusters)
                appended_x, appended_y = self.append_diff_label_mid_points(sess, net, diff_label_pair_list)
                total_appended_x += appended_x
                total_appended_y += appended_y

                # Phase 3: mid-points between centers of same-labelled
                # cluster pairs.
                same_label_pair_list = self.select_same_label_point_pair(centers, centers_label, clusters)
                appended_x, appended_y = self.append_same_label_mid_points(same_label_pair_list)
                total_appended_x += appended_x
                total_appended_y += appended_y

                # Grow the training set in place with everything appended
                # this round.
                self.train_set_x += total_appended_x
                self.train_set_y += total_appended_y

                # NOTE(review): only the phase-3 (same-label) batch is
                # recorded under "mid_point" — appended_x was overwritten by
                # each phase, so the phase-2 points are not logged. Confirm
                # this is intended.
                appending_dict["mid_point"] = appended_x
                appended_point_list.append(appending_dict)

                util.save_model(sess, self.model_folder, self.model_file)
                count += 1

        communication.send_training_finish_message()
        return train_acc_list, test_acc_list, data_point_number_list, appended_point_list