Example #1
import numpy as np

class MLP:
    def __init__(self, input_size, hidden_size_1, hidden_size_2, output_size):
        # Input layer with a ReLU activation, one hidden layer, and a softmax output layer.
        self.input_layer = InputLayer(input_size, hidden_size_1, ReLU)
        self.hidden_layer = HiddenLayer(hidden_size_1, hidden_size_2)
        self.output_layer = SoftmaxOutputLayer(hidden_size_2, output_size)

    def predict(self, x):
        # Forward pass through every layer, then pick the class with the highest probability.
        x = self.input_layer.forward(x)
        x = self.hidden_layer.forward(x)
        prob = self.output_layer.predict(x)
        pred = np.argmax(prob, axis=-1)
        return pred

    def loss(self, x, y):
        # Forward pass that ends with the loss computed against the labels y.
        x = self.input_layer.forward(x)
        x = self.hidden_layer.forward(x)
        loss = self.output_layer.forward(x, y)
        return loss

    def gradient(self):
        # Backpropagate from the output layer to the input layer; each layer
        # stores its parameter gradients (dW, db) as a side effect.
        d_prev = 1
        d_prev = self.output_layer.backward(d_prev=d_prev)
        d_prev = self.hidden_layer.backward(d_prev)
        self.input_layer.backward(d_prev)

    def update(self, learning_rate):
        # Vanilla gradient-descent step on every layer's weights and biases.
        self.input_layer.W -= self.input_layer.dW * learning_rate
        self.input_layer.b -= self.input_layer.db * learning_rate
        self.hidden_layer.W -= self.hidden_layer.dW * learning_rate
        self.hidden_layer.b -= self.hidden_layer.db * learning_rate
        self.output_layer.W -= self.output_layer.dW * learning_rate
        self.output_layer.b -= self.output_layer.db * learning_rate
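
For reference, here is a minimal training-loop sketch (an assumption, not part of the original notebook) showing how loss, gradient, and update are meant to be combined; the arrays x_train, y_train, x_test and the size and learning-rate values are placeholders.

# Hypothetical usage sketch: the layer sizes, learning rate, and the
# x_train / y_train / x_test arrays are placeholders, not values from this notebook.
model = MLP(input_size=784, hidden_size_1=128, hidden_size_2=64, output_size=10)
learning_rate = 0.01

for epoch in range(10):
    loss = model.loss(x_train, y_train)  # forward pass; layers cache what backward needs
    model.gradient()                     # backward pass fills dW / db in every layer
    model.update(learning_rate)          # plain gradient-descent step on W and b
    print(f'epoch {epoch}: loss = {loss:.4f}')

pred = model.predict(x_test)             # class indices via argmax over the softmax probabilities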
Example #2
 [0.         1.06672439]]
Backward: 
 [[2.8863898  0.20017317]
 [0.96212993 0.06672439]]
dW: 
 [[ 0. -3.]
 [ 0. 10.]]
db: 
 [0. 4.]
"""
hidden_layer = HiddenLayer(2, 2)
# Overwrite the initial parameters with fixed values so the check is deterministic.
hidden_layer.W = np.array([[-5.0, -1.25], [0.01, -10.0]])
hidden_layer.b = np.array([-10.0, 1.0])
temp6 = np.array([[-1, 3], [0.0, 1.0]])  # input for the forward pass
temp7 = np.array([[-1, 3], [0.0, 1.0]])  # upstream gradient for the backward pass
print('Forward: \n', hidden_layer.forward(temp6))
print('Backward: \n', hidden_layer.backward(temp7))
print('dW: \n', hidden_layer.dW)
print('db: \n', hidden_layer.db)
print()
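
The printed values can also be compared programmatically; the helper below is an assumption (it is not in the original notebook) that hard-codes the expected dW and db from the docstring above and checks them with np.allclose.

# Hypothetical check (not in the original): compare the computed gradients
# against the expected values listed in the docstring above.
expected_dW = np.array([[0., -3.], [0., 10.]])
expected_db = np.array([0., 4.])
assert np.allclose(hidden_layer.dW, expected_dW), 'dW mismatch'
assert np.allclose(hidden_layer.db, expected_db), 'db mismatch'
print('HiddenLayer gradient check passed')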

print('===== SigmoidOutputLayer Check =====')
"""
The results should be exactly the same as below:

Binary Cross-entropy Loss: 
 15.581616346746953
Predict: 
 [[0.90437415]
 [0.00750045]]