# Forward and backward propagation through the "buy apples and oranges" graph:
#     price = (apple * apple_num + orange * orange_num) * tax
# The forward pass computes the total price. The backward pass yields the
# gradient of the total price with respect to every input.
# MulLayer / AddLayer come from ch05.layer_naive. Each backward() call returns a
# pair of gradients, one for each forward() input, in the same order.
from ch05.layer_naive import MulLayer, AddLayer

apple = 100
apple_num = 2
orange = 150
orange_num = 3
tax = 1.1

# Layers: one node per operation in the graph.
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# Forward propagation
apple_price = mul_apple_layer.forward(apple, apple_num)
orange_price = mul_orange_layer.forward(orange, orange_num)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)
price = mul_tax_layer.forward(all_price, tax)

# Backward propagation visits the layers in reverse order of the forward pass.
# The seed gradient at the output is d(price)/d(price) = 1, so `dprice`, not the
# forward result `price`, is what must flow into the last layer.
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)
dorange, dorange_num = mul_orange_layer.backward(dorange_price)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print(price)
print(dapple_num, dapple, dorange_num, dorange, dtax)
# coding: utf-8
# Forward and backward pass through a two-node multiplication graph:
#     total = (unit_price * quantity) * tax_rate
# It prints the total and the gradient of the total with respect to each input.
from ch05.layer_naive import MulLayer

unit_price, quantity, tax_rate = 100, 2, 1.1

# Two multiplication nodes: subtotal first, then the tax applied to it.
subtotal_node = MulLayer()
taxed_node = MulLayer()

# Forward pass
subtotal = subtotal_node.forward(unit_price, quantity)
total = taxed_node.forward(subtotal, tax_rate)

# Backward pass, in reverse node order, seeded with d(total)/d(total) = 1.
seed_grad = 1
grad_subtotal, grad_tax = taxed_node.backward(seed_grad)
grad_unit_price, grad_quantity = subtotal_node.backward(grad_subtotal)

print("price:", int(total))
print("dApple:", grad_unit_price)
print("dApple_num:", int(grad_quantity))
print("dTax:", grad_tax)
# Forward and backward propagation through the graph
#     price = (apple * apple_num + orange * orange_num) * tax
# Gradient names follow the convention `d<name>` = d(price)/d(<name>). The
# original version swapped these names: the outputs of the add node were
# called dapple/dorange, and the unit-price gradients were called
# dapple_price/dorange_price. The printed values are unchanged.
from ch05.layer_naive import MulLayer, AddLayer

apple = 100
orange = 150
apple_num = 2
orange_num = 3
tax = 1.1

# Layers: one node per operation in the graph.
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# forward
apple_price = mul_apple_layer.forward(apple, apple_num)
orange_price = mul_orange_layer.forward(orange, orange_num)
all_price = add_apple_orange_layer.forward(apple_price, orange_price)
price = mul_tax_layer.forward(all_price, tax)
print(price)

# backward: reverse layer order, seeded with d(price)/d(price) = 1.
dprice = 1
dall_price, dtax = mul_tax_layer.backward(dprice)
# The add node hands back gradients for its two inputs, apple_price and orange_price.
dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)
# Each multiply node hands back gradients for (unit price, count).
dorange, dorange_num = mul_orange_layer.backward(dorange_price)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

print(dtax, dapple, dapple_num, dorange, dorange_num)