Ejemplo n.º 1
0
def _create_npu_op_unary_elementwise(serial_unary_elementwise):
    operator_type = serial_unary_elementwise.operator_type
    if operator_type == "ABS":
        op = vapi.NpuElementWiseOp.ABS
    if operator_type == "CLZ":
        op = vapi.NpuElementWiseOp.CLZ

    npu_unary_elementwise_op = vapi.NpuElementWiseOperation(op)
    npu_unary_elementwise_op.ifm = _create_npu_feature_map(
        serial_unary_elementwise.ifm)
    npu_unary_elementwise_op.ofm = _create_npu_feature_map(
        serial_unary_elementwise.ofm)

    npu_unary_elementwise_op.activation = _create_npu_activation(
        serial_unary_elementwise.activation)
    if (npu_unary_elementwise_op.activation
            and npu_unary_elementwise_op.activation.op_type
            == vapi.NpuActivationOp.NONE_OR_RELU):
        _convert_clip_bounds(npu_unary_elementwise_op)

    npu_unary_elementwise_op.rounding_mode = _create_npu_rounding_mode(
        serial_unary_elementwise.rounding_mode)
    npu_unary_elementwise_op.block_config = _create_npu_block_config(
        serial_unary_elementwise.block_config)

    if not npu_unary_elementwise_op.block_config:
        target_accel_type = vela_api.get_accelerator_config()
        block_config = vela_api.get_optimal_block_config(
            npu_unary_elementwise_op, target_accel_type)
        npu_unary_elementwise_op.block_config = block_config

    return npu_unary_elementwise_op
Ejemplo n.º 2
0
def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBinaryElementwise):
    operator_type = serial_binary_elementwise.operator_type
    if operator_type == "ADD":
        op = vapi.NpuElementWiseOp.ADD
    elif operator_type == "SUB":
        op = vapi.NpuElementWiseOp.SUB
    elif operator_type == "MUL":
        op = vapi.NpuElementWiseOp.MUL
    elif operator_type == "MIN":
        op = vapi.NpuElementWiseOp.MIN
    elif operator_type == "MAX":
        op = vapi.NpuElementWiseOp.MAX
    elif operator_type == "SHR":
        op = vapi.NpuElementWiseOp.SHR
    elif operator_type == "SHL":
        op = vapi.NpuElementWiseOp.SHL

    npu_binary_elementwise_op = vapi.NpuElementWiseOperation(op)
    npu_binary_elementwise_op.ifm = _create_npu_feature_map(serial_binary_elementwise.ifm)
    npu_binary_elementwise_op.ifm2 = _create_npu_feature_map(serial_binary_elementwise.ifm2)
    npu_binary_elementwise_op.ofm = _create_npu_feature_map(serial_binary_elementwise.ofm)
    npu_binary_elementwise_op.reversed_operands = serial_binary_elementwise.reversed_operands

    npu_binary_elementwise_op.activation = _create_npu_activation(
        serial_binary_elementwise.activation
    )
    if (
        npu_binary_elementwise_op.activation
        and npu_binary_elementwise_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
    ):
        _convert_clip_bounds(npu_binary_elementwise_op)

    npu_binary_elementwise_op.rounding_mode = _create_npu_rounding_mode(
        serial_binary_elementwise.rounding_mode
    )
    npu_binary_elementwise_op.block_config = _create_npu_block_config(
        serial_binary_elementwise.block_config
    )

    if not npu_binary_elementwise_op.block_config:
        target_accel_config = vela_api.get_accelerator_config()
        block_config = vela_api.get_optimal_block_config(
            npu_binary_elementwise_op, target_accel_config
        )
        npu_binary_elementwise_op.block_config = block_config

    return npu_binary_elementwise_op