コード例 #1
0
    def test_multilayer_perceptron_train(self):
        """Train a 2-4-1 network on the XNOR truth table and check its ranking.

        The targets are 1 when both inputs agree and 0 when they differ.
        Instead of thresholding the outputs, the test only requires that each
        "agreeing" input scores strictly higher than each "disagreeing" input.
        """
        network = MultiLayerNeuralNetwork(
            [2, 4, 1],
            threshold=0.5,
            learning_coefficient=0.5,
            sigmoid_alpha=5,
            print_error=False,
        )

        # XNOR truth table as [inputs, targets] pairs.
        xnor_samples = [
            [[0, 0], [1]],
            [[1, 0], [0]],
            [[0, 1], [0]],
            [[1, 1], [1]],
        ]
        network.train(xnor_samples)

        # Every (higher, lower) pair must be ordered by the trained network.
        # NOTE(review): if the network initialises its weights randomly, this
        # check may be flaky; consider seeding the RNG -- TODO confirm.
        expected_order = (
            ([0, 0], [1, 0]),
            ([0, 0], [0, 1]),
            ([1, 1], [1, 0]),
            ([1, 1], [0, 1]),
        )
        for higher, lower in expected_order:
            assert network.predict(higher) > network.predict(lower)
コード例 #2
0
ファイル: main.py プロジェクト: kokukuma/EoNN
def multilayer_perceptron():
    """Train a 2-4-1 network on generated 2-D data and plot the result.

    Generates a mapping of points to 0/1 labels over x, y in [0, 10] via
    ``TrainingData.sin_function_data``, scatters the points by label, trains
    the network on every sample in shuffled order, then draws the learned
    separation line and calls ``plt.show()``.
    """
    network = MultiLayerNeuralNetwork([2, 4, 1],
                                      threshold=0.5,
                                      learning_coefficient=0.5,
                                      sigmoid_alpha=5)

    x_range, y_range = [0, 10], [0, 10]
    # NOTE(review): the meaning of the third argument (5) is not visible
    # here -- TODO confirm against TrainingData.sin_function_data.
    labelled_points = TrainingData.sin_function_data(x_range, y_range, 5)
    train_data = TrainingData.change_format(labelled_points)

    # Plot the training data: label 0 in green, label 1 in blue.
    fig = plt.figure()
    for label, colour in ((0, 'g'), (1, 'b')):
        members = [point for point, tag in labelled_points.items()
                   if tag == label]
        scat(fig, members, color=colour)

    # Train on every sample, presented in random order.
    n_samples = len(train_data)
    random.shuffle(train_data)
    network.train(train_data[:n_samples])

    # Compute the model's y for each x (per the original comment);
    # presumably this traces the decision boundary -- TODO confirm.
    boundary = get_predict_list(x_range, y_range, network)

    # Draw the post-training separation line and display the figure.
    plot(fig, boundary)

    plt.show()