Ejemplo n.º 1
0
def create_simple_if_with_two_outputs(condition_val):
    """Build an ``If`` node with two outputs from explicit body descriptors.

    The then-body takes parameters X, Y, Z and produces ``add(X, Y)`` and
    ``multiply(Y, Z)``.  The else-body takes parameters X, Z, W and produces
    ``add(X, W)`` and ``power(W, Z)``.  The outer inputs are the scalar
    float32 constants X=15.0, Y=-5.0, Z=4.0 and W=2.0.

    Args:
        condition_val: Value wrapped in a boolean constant and used as the
            condition of the ``If`` node.

    Returns:
        The node returned by ``ov.if_op``.
    """
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    condition = ov.constant(condition_val, dtype=bool)

    # then_body
    X_t = ov.parameter([], np.float32, "X")
    Y_t = ov.parameter([], np.float32, "Y")
    Z_t = ov.parameter([], np.float32, "Z")

    add_t = ov.add(X_t, Y_t)
    mul_t = ov.multiply(Y_t, Z_t)
    then_body_res_1 = ov.result(add_t)
    then_body_res_2 = ov.result(mul_t)
    then_body = GraphBody([X_t, Y_t, Z_t], [then_body_res_1, then_body_res_2])
    # NOTE(review): the first descriptor argument looks like the index of the
    # If node's input (1..4 for X, Y, Z, W, with 0 presumably taken by the
    # condition) and the second the body parameter index -- confirm against
    # the if_op API.
    then_body_inputs = [
        TensorIteratorInvariantInputDesc(1, 0),
        TensorIteratorInvariantInputDesc(2, 1),
        TensorIteratorInvariantInputDesc(3, 2)
    ]
    then_body_outputs = [
        TensorIteratorBodyOutputDesc(0, 0),
        TensorIteratorBodyOutputDesc(1, 1)
    ]

    # else_body
    X_e = ov.parameter([], np.float32, "X")
    Z_e = ov.parameter([], np.float32, "Z")
    W_e = ov.parameter([], np.float32, "W")

    add_e = ov.add(X_e, W_e)
    pow_e = ov.power(W_e, Z_e)
    else_body_res_1 = ov.result(add_e)
    else_body_res_2 = ov.result(pow_e)
    else_body = GraphBody([X_e, Z_e, W_e], [else_body_res_1, else_body_res_2])
    # The else-body has no Y parameter, so input index 2 is not referenced.
    else_body_inputs = [
        TensorIteratorInvariantInputDesc(1, 0),
        TensorIteratorInvariantInputDesc(3, 1),
        TensorIteratorInvariantInputDesc(4, 2)
    ]
    else_body_outputs = [
        TensorIteratorBodyOutputDesc(0, 0),
        TensorIteratorBodyOutputDesc(1, 1)
    ]

    X = ov.constant(15.0, dtype=np.float32)
    Y = ov.constant(-5.0, dtype=np.float32)
    Z = ov.constant(4.0, dtype=np.float32)
    W = ov.constant(2.0, dtype=np.float32)
    if_node = ov.if_op(condition, [X, Y, Z, W], (then_body, else_body),
                       (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    return if_node
Ejemplo n.º 2
0
def create_simple_if_with_two_outputs(condition_val):
    """Build an ``If`` node with two outputs using ``Model`` bodies.

    The then-body takes parameters X, Y, Z and produces ``add(X, Y)`` and
    ``multiply(Y, Z)``.  The else-body takes parameters X, Z, W and produces
    ``add(X, W)`` and ``power(W, Z)``.  The outer inputs are the scalar
    float32 constants X=15.0, Y=-5.0, Z=4.0 and W=2.0, wired to the body
    parameters through ``set_input``.

    Args:
        condition_val: Value wrapped in a boolean constant and used as the
            condition of the ``If`` node.

    Returns:
        The configured node created by ``ov.if_op``.
    """
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    condition = ov.constant(condition_val, dtype=bool)

    # then_body
    X_t = ov.parameter([], np.float32, "X")
    Y_t = ov.parameter([], np.float32, "Y")
    Z_t = ov.parameter([], np.float32, "Z")

    add_t = ov.add(X_t, Y_t)
    mul_t = ov.multiply(Y_t, Z_t)
    then_body_res_1 = ov.result(add_t)
    then_body_res_2 = ov.result(mul_t)
    then_body = Model([then_body_res_1, then_body_res_2], [X_t, Y_t, Z_t],
                      "then_body_function")

    # else_body
    X_e = ov.parameter([], np.float32, "X")
    Z_e = ov.parameter([], np.float32, "Z")
    W_e = ov.parameter([], np.float32, "W")

    add_e = ov.add(X_e, W_e)
    pow_e = ov.power(W_e, Z_e)
    else_body_res_1 = ov.result(add_e)
    else_body_res_2 = ov.result(pow_e)
    else_body = Model([else_body_res_1, else_body_res_2], [X_e, Z_e, W_e],
                      "else_body_function")

    X = ov.constant(15.0, dtype=np.float32)
    Y = ov.constant(-5.0, dtype=np.float32)
    Z = ov.constant(4.0, dtype=np.float32)
    W = ov.constant(2.0, dtype=np.float32)

    if_node = ov.if_op(condition)
    if_node.set_then_body(then_body)
    if_node.set_else_body(else_body)
    # Each call passes an outer value followed by the then-body and else-body
    # parameters; None is passed where a body has no matching parameter
    # (no Y in the else-body, no W in the then-body).
    if_node.set_input(X.output(0), X_t, X_e)
    if_node.set_input(Y.output(0), Y_t, None)
    if_node.set_input(Z.output(0), Z_t, Z_e)
    if_node.set_input(W.output(0), None, W_e)
    # Pair the then-body and else-body results, one call per output.
    if_node.set_output(then_body_res_1, else_body_res_1)
    if_node.set_output(then_body_res_2, else_body_res_2)
    return if_node
Ejemplo n.º 3
0
def binary_op(op_str, a, b):
    """Apply the binary operation selected by ``op_str`` to ``a`` and ``b``.

    The symbols ``"+"``, ``"-"``, ``"*"`` and ``"/"`` use the corresponding
    Python operator on the operands; the remaining names call the matching
    ``ov`` function (for example ``"Add"`` calls ``ov.add``).

    Args:
        op_str: Key selecting the operation.
        a: Left operand.
        b: Right operand.

    Returns:
        The result of the selected operation.  Nothing is returned here when
        ``op_str`` matches none of the keys in the table below.
    """
    # Ordered (key, handler) pairs.  Handlers are lambdas so that ``ov`` is
    # only touched for the operation that is actually selected.
    dispatch = (
        ("+", lambda lhs, rhs: lhs + rhs),
        ("Add", lambda lhs, rhs: ov.add(lhs, rhs)),
        ("-", lambda lhs, rhs: lhs - rhs),
        ("Sub", lambda lhs, rhs: ov.subtract(lhs, rhs)),
        ("*", lambda lhs, rhs: lhs * rhs),
        ("Mul", lambda lhs, rhs: ov.multiply(lhs, rhs)),
        ("/", lambda lhs, rhs: lhs / rhs),
        ("Div", lambda lhs, rhs: ov.divide(lhs, rhs)),
        ("Equal", lambda lhs, rhs: ov.equal(lhs, rhs)),
        ("Greater", lambda lhs, rhs: ov.greater(lhs, rhs)),
        ("GreaterEq", lambda lhs, rhs: ov.greater_equal(lhs, rhs)),
        ("Less", lambda lhs, rhs: ov.less(lhs, rhs)),
        ("LessEq", lambda lhs, rhs: ov.less_equal(lhs, rhs)),
        ("Maximum", lambda lhs, rhs: ov.maximum(lhs, rhs)),
        ("Minimum", lambda lhs, rhs: ov.minimum(lhs, rhs)),
        ("NotEqual", lambda lhs, rhs: ov.not_equal(lhs, rhs)),
        ("Power", lambda lhs, rhs: ov.power(lhs, rhs)),
    )

    # Find the first key equal to op_str, comparing in table order.
    selected = None
    for key, handler in dispatch:
        if op_str == key:
            selected = handler
            break

    if selected is not None:
        return selected(a, b)