def greater(left_node, right_node, name=None):
    # type: (NodeInput, NodeInput, str) -> Node
    """Build a node that tests, element-wise, whether left is greater than right.

    :param left_node: The first input node providing data.
    :param right_node: The second input node providing data.
    :param name: The optional new name for output node.
    :return: The node performing the element-wise ``left_node > right_node`` check.
    """
    # NOTE(review): ``name`` is not referenced in this body; presumably it is
    # consumed by a decorator outside this view -- confirm.
    comparison_node = Greater(left_node, right_node)
    return comparison_node
def binary_op(op_str, a, b):
    """Apply the binary operation selected by ``op_str`` to ``a`` and ``b``.

    The selectors ``'+'``, ``'-'``, ``'*'`` and ``'/'`` apply the matching
    Python operator to the operands. The remaining selectors call an op
    constructor with ``(a, b)``: ``'Add'``, ``'Sub'`` (Subtract), ``'Mul'``
    (Multiply), ``'Div'`` (Divide), ``'Dot'``, ``'Equal'``, ``'Greater'``,
    ``'GreaterEq'``, ``'Less'``, ``'LessEq'``, ``'Maximum'``, ``'Minimum'``,
    ``'NotEqual'`` and ``'Power'``.

    :param op_str: Selector naming the operation to perform.
    :param a: Left operand.
    :param b: Right operand.
    :return: The result of the selected operation, or ``None`` when ``op_str``
             matches none of the known selectors.
    """
    # Each builder is a deferred call, so only the selected operation is
    # evaluated and only its constructor name is looked up.
    builders = (
        ('+', lambda: a + b),
        ('Add', lambda: Add(a, b)),
        ('-', lambda: a - b),
        ('Sub', lambda: Subtract(a, b)),
        ('*', lambda: a * b),
        ('Mul', lambda: Multiply(a, b)),
        ('/', lambda: a / b),
        ('Div', lambda: Divide(a, b)),
        ('Dot', lambda: Dot(a, b)),
        ('Equal', lambda: Equal(a, b)),
        ('Greater', lambda: Greater(a, b)),
        ('GreaterEq', lambda: GreaterEq(a, b)),
        ('Less', lambda: Less(a, b)),
        ('LessEq', lambda: LessEq(a, b)),
        ('Maximum', lambda: Maximum(a, b)),
        ('Minimum', lambda: Minimum(a, b)),
        ('NotEqual', lambda: NotEqual(a, b)),
        ('Power', lambda: Power(a, b)),
    )

    # Scan in declaration order, comparing with ``==``; the first match wins.
    for selector, build in builders:
        if op_str == selector:
            return build()

    # Unknown selector: no operation is performed.
    return None
def greater(left_node, right_node, name=None):
    # type: (NodeInput, NodeInput, str) -> Node
    """Build a node testing element-wise whether left_node is greater than right_node."""
    # NOTE(review): ``name`` is not referenced in this body; presumably it is
    # consumed by a decorator outside this view -- confirm.
    comparison_node = Greater(left_node, right_node)
    return comparison_node
# Softmax Logits = X5 Exp = Exp(Logits) Max = Reduce(Exp, make_float32_constant(0., [], set()), MaxFn, AxisSet({1})) MaxBroadcast = Broadcast(Max, Shape([bz, 10]), AxisSet({1})) Softmax = Exp / MaxBroadcast # Loss LogSoftmax = Log(Softmax) Loss = Sum(LogSoftmax * LabelOneHot, AxisSet({0, 1})) / make_float32_constant( float(bz), [], set()) # Derivatives dLogits = Softmax - LabelOneHot dX5 = dLogits dX4 = Dot(dX5, transpose(W2, Shape([1, 0]))) dW2 = Dot(transpose(X4, Shape([1, 0])), dX5) db2 = Sum(dX5, AxisSet({0})) dX3 = Convert((Greater(X3, make_float32_constant(0., [bz, 100], {0, 1}))), float_element_type) * dX4 dX2 = Dot(dX3, transpose(W1, Shape([1, 0]))) dW1 = Dot(transpose(X2, Shape([1, 0])), dX3) db1 = Sum(dX3, AxisSet({0})) nW1 = W1 - make_float32_constant_like(lr, dW1) * dW1 nb1 = b1 - make_float32_constant_like(lr, db1) * db1 nW2 = W2 - make_float32_constant_like(lr, dW2) * dW2 nb2 = b2 - make_float32_constant_like(lr, db2) * db2