Esempio n. 1
0
def binary_op(op_str, a, b):
    """Apply the binary operation selected by ``op_str`` to ``a`` and ``b``.

    The symbols ``'+'``, ``'-'``, ``'*'`` and ``'/'`` apply the matching
    Python operator to the operands. Every other supported key calls the
    callable it names with ``(a, b)``; note the abbreviated keys
    ``'Sub'`` -> ``Subtract``, ``'Mul'`` -> ``Multiply`` and
    ``'Div'`` -> ``Divide``.

    :param op_str: key identifying the operation (a string).
    :param a: left operand.
    :param b: right operand.
    :return: result of the selected operation, or ``None`` when ``op_str``
             is not one of the supported keys.
    """
    # Thunks keep evaluation lazy: only the selected entry touches the
    # operands or resolves an op name, just like a single if/elif branch.
    builders = {
        '+': lambda: a + b,
        'Add': lambda: Add(a, b),
        '-': lambda: a - b,
        'Sub': lambda: Subtract(a, b),
        '*': lambda: a * b,
        'Mul': lambda: Multiply(a, b),
        '/': lambda: a / b,
        'Div': lambda: Divide(a, b),
        'Dot': lambda: Dot(a, b),
        'Equal': lambda: Equal(a, b),
        'Greater': lambda: Greater(a, b),
        'GreaterEq': lambda: GreaterEq(a, b),
        'Less': lambda: Less(a, b),
        'LessEq': lambda: LessEq(a, b),
        'Maximum': lambda: Maximum(a, b),
        'Minimum': lambda: Minimum(a, b),
        'NotEqual': lambda: NotEqual(a, b),
        'Power': lambda: Power(a, b),
    }
    build = builders.get(op_str)
    # Unknown keys fall through to None rather than raising.
    return build() if build is not None else None
Esempio n. 2
0
def relu(op):  # type: (Node) -> Node
    """Return ``Maximum`` of ``op`` and a zero constant.

    The zero operand comes from ``make_float32_constant_like(0., op)``
    (presumably a float32 constant shaped like ``op`` -- confirm against
    that helper).
    """
    zero_like = make_float32_constant_like(0., op)
    return Maximum(op, zero_like)
Esempio n. 3
0
def maximum(left_node, right_node, name=None):
    # type: (NodeInput, NodeInput, str) -> Node
    """Build a node computing the elementwise maximum of the two inputs.

    :param left_node: first input of the maximum operation.
    :param right_node: second input of the maximum operation.
    :param name: accepted, but not used by this implementation.
    :return: the ``Maximum`` node created from both inputs.
    """
    max_node = Maximum(left_node, right_node)
    return max_node
Esempio n. 4
0
from ngraph.impl.op import OneHot

from typing import List, Dict, Set

# Element types used for the float and integer graph parameters below.
float_element_type = Type.f32
int_element_type = Type.i32
# Leading dimension of the Input/Label shapes (presumably the batch size).
bz = 53
# Not referenced in the visible code; presumably the learning rate -- TODO confirm.
lr = 0.2

# Float parameter of shape [bz, 28, 28] (presumably a batch of 28x28 images).
Input = Parameter(float_element_type, Shape([bz, 28, 28]))
# Integer parameter of shape [bz].
Label = Parameter(int_element_type, Shape([bz]))
# Label passed through OneHot with target shape [bz, 10] (third argument 1 is
# presumably the one-hot axis), then converted to the float element type.
LabelOneHot = Convert((OneHot(Label, Shape([bz, 10]), 1)), float_element_type)

# Two float parameters with an empty shape, used as the inputs of MaxFn.
MaxParam1 = Parameter(float_element_type, Shape([]))
MaxParam2 = Parameter(float_element_type, Shape([]))
# Function wrapping Maximum over the two parameters above; 'mnist' is the
# third constructor argument (presumably the function's name -- confirm).
MaxFn = Function(Maximum(MaxParam1, MaxParam2), [MaxParam1, MaxParam2],
                 'mnist')


def make_scalar_constant(elem_type, scalar, shape=None, axis_set=None):
    # type: (Type, float, Shape, AxisSet) -> Broadcast
    """Create a scalar Constant node and wrap it in a Broadcast.

    :param elem_type: element type handed to ``Constant``, e.g. ``Type.f32``.
    :param scalar: value stored in the constant.
    :param shape: shape passed to ``Broadcast``; defaults to ``Shape([])``.
    :param axis_set: axes passed to ``Broadcast``; defaults to
                     ``AxisSet(set())``.
    :return: the ``Broadcast`` node built on top of the scalar constant
             (a graph node, not a Python float).
    """
    if shape is None:
        shape = Shape([])
    if axis_set is None:
        axis_set = AxisSet(set())
    # The constant itself is always created with an empty Shape; the
    # requested ``shape`` is only applied by the Broadcast around it.
    constant_op = Constant(elem_type, Shape([]), [scalar])
    return Broadcast(constant_op, shape, axis_set)