def dot(left_node, right_node, reduction_axes_count=None, name=None):
    # type: (Node, Node, int, str) -> Node
    """Build a generalized dot-product node from two input nodes.

    Handles scalar-tensor products, matrix-vector products and matrix
    multiplication.

    :param left_node: Node supplying the left-hand operand.
    :param right_node: Node supplying the right-hand operand.
    :param reduction_axes_count: How many axes the dot product reduces over.
        When omitted (None), the two-argument form of Dot is used.
    :param name: Optional name for the output node. NOTE: currently accepted
        but not forwarded to the created node.
    :return: The Dot node combining both inputs.
    """
    # Only pass the reduction count through when the caller supplied one,
    # so Dot's own two-argument behaviour applies otherwise.
    dot_args = [left_node, right_node]
    if reduction_axes_count is not None:
        dot_args.append(reduction_axes_count)
    return Dot(*dot_args)
def binary_op(op_str, a, b):
    """Apply the binary operation named by *op_str* to operands *a* and *b*.

    Symbolic names ('+', '-', '*', '/') use Python's operators on the
    operands directly; the remaining names construct the corresponding op
    node class (Add, Subtract, Multiply, Divide, Dot, Equal, Greater,
    GreaterEq, Less, LessEq, Maximum, Minimum, NotEqual, Power).

    :param op_str: Operator symbol or op name.
    :param a: Left-hand operand.
    :param b: Right-hand operand.
    :return: The result of applying the operation to a and b.
    :raises ValueError: If op_str is not a recognised operation (previously
        the function silently returned None in that case).
    """
    # Lambdas defer the lookup of each op class until that entry is actually
    # used, mirroring the lazy evaluation of the original if/elif branches.
    operations = {
        '+': lambda x, y: x + y,
        'Add': lambda x, y: Add(x, y),
        '-': lambda x, y: x - y,
        'Sub': lambda x, y: Subtract(x, y),
        '*': lambda x, y: x * y,
        'Mul': lambda x, y: Multiply(x, y),
        '/': lambda x, y: x / y,
        'Div': lambda x, y: Divide(x, y),
        'Dot': lambda x, y: Dot(x, y),
        'Equal': lambda x, y: Equal(x, y),
        'Greater': lambda x, y: Greater(x, y),
        'GreaterEq': lambda x, y: GreaterEq(x, y),
        'Less': lambda x, y: Less(x, y),
        'LessEq': lambda x, y: LessEq(x, y),
        'Maximum': lambda x, y: Maximum(x, y),
        'Minimum': lambda x, y: Minimum(x, y),
        'NotEqual': lambda x, y: NotEqual(x, y),
        'Power': lambda x, y: Power(x, y),
    }
    if op_str not in operations:
        # Fail loudly: a silently returned None would only surface later as
        # a confusing error far from the typo that caused it.
        raise ValueError('Unsupported binary operation: {!r}'.format(op_str))
    return operations[op_str](a, b)
def binary_op(op_str, a, b):
    """Apply the binary operation named by *op_str* to operands *a* and *b*.

    Symbolic names ('+', '-', '*', '/') use Python's operators on the
    operands directly; the remaining names call the matching ng.* factory
    function (ng.add, ng.subtract, ...), except 'Dot', which constructs the
    Dot class directly.

    :param op_str: Operator symbol or op name.
    :param a: Left-hand operand.
    :param b: Right-hand operand.
    :return: The result of applying the operation to a and b.
    :raises ValueError: If op_str is not a recognised operation (previously
        the function silently returned None in that case).
    """
    # Lambdas defer attribute/name lookup until the chosen entry is called,
    # mirroring the lazy evaluation of the original if/elif branches.
    operations = {
        '+': lambda x, y: x + y,
        'Add': lambda x, y: ng.add(x, y),
        '-': lambda x, y: x - y,
        'Sub': lambda x, y: ng.subtract(x, y),
        '*': lambda x, y: x * y,
        'Mul': lambda x, y: ng.multiply(x, y),
        '/': lambda x, y: x / y,
        'Div': lambda x, y: ng.divide(x, y),
        # NOTE(review): unlike the other named ops this uses the Dot class
        # rather than an ng.* helper — confirm this is intentional.
        'Dot': lambda x, y: Dot(x, y),
        'Equal': lambda x, y: ng.equal(x, y),
        'Greater': lambda x, y: ng.greater(x, y),
        'GreaterEq': lambda x, y: ng.greater_equal(x, y),
        'Less': lambda x, y: ng.less(x, y),
        'LessEq': lambda x, y: ng.less_equal(x, y),
        'Maximum': lambda x, y: ng.maximum(x, y),
        'Minimum': lambda x, y: ng.minimum(x, y),
        'NotEqual': lambda x, y: ng.not_equal(x, y),
        'Power': lambda x, y: ng.power(x, y),
    }
    if op_str not in operations:
        # Fail loudly: a silently returned None would only surface later as
        # a confusing error far from the typo that caused it.
        raise ValueError('Unsupported binary operation: {!r}'.format(op_str))
    return operations[op_str](a, b)
def dot(left_node, right_node, name=None):
    # type: (Node, Node, str) -> Node
    """Create a node performing matrix multiplication of the two inputs.

    :param left_node: Left-hand input node.
    :param right_node: Right-hand input node.
    :param name: Optional output-node name. NOTE: accepted but not applied
        to the created node.
    :return: The resulting Dot node.
    """
    product = Dot(left_node, right_node)
    return product
def relu(op):
    # type: (Node) -> Node
    """Relu operator.

    Returns Maximum(op, 0), where the zero is built with
    make_float32_constant_like(0., op) (presumably a float32 constant shaped
    like op — confirm).
    """
    return Maximum(op, make_float32_constant_like(0., op))


# Flatten
# Reshape Input to Shape([bz, 784]). AxisVector([0, 1, 2]) presumably lists
# the input axes in order for a rank-3 input (e.g. bz x 28 x 28) — confirm.
X1 = Reshape(Input, AxisVector([0, 1, 2]), Shape([bz, 784]))

# Normalize
# Divide by 255. — presumably scales 8-bit pixel values into [0, 1]; confirm
# the input range.
X2 = X1 / make_float32_constant_like(255., X1)

# Affine 1
# 784 -> 100 layer. The bias is broadcast to Shape([bz, 100]) with
# AxisSet({0}), presumably repeating it across the batch axis.
W1 = Parameter(float_element_type, Shape([784, 100]))
b1 = Parameter(float_element_type, Shape([100]))
X3 = Dot(X2, W1) + Broadcast(b1, Shape([bz, 100]), AxisSet({0}))
X4 = relu(X3)

# Affine 2
# 100 -> 10 layer, bias broadcast the same way to Shape([bz, 10]).
W2 = Parameter(float_element_type, Shape([100, 10]))
b2 = Parameter(float_element_type, Shape([10]))
X5 = Dot(X4, W2) + Broadcast(b2, Shape([bz, 10]), AxisSet({0}))

# Softmax
Logits = X5
# NOTE: this rebinds the name Exp from the op class to the resulting node;
# the Exp class is no longer reachable under this name after this line.
Exp = Exp(Logits)
# Reduce Exp over axis 1 with MaxFn and initial value 0. — presumably the
# per-row maximum of Exp.
Max = Reduce(Exp, make_float32_constant(0., [], set()), MaxFn, AxisSet({1}))
# Broadcast the per-row result back to Shape([bz, 10]) along axis 1.
MaxBroadcast = Broadcast(Max, Shape([bz, 10]), AxisSet({1}))
# NOTE(review): a true softmax divides exp(logits) by the per-row SUM of
# exp(logits). Dividing by the per-row max reduction (as here) does not make
# each row sum to 1 — confirm whether a sum-reduction was intended.
Softmax = Exp / MaxBroadcast

# Loss