Esempio n. 1
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""L2NormalizeGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for L2NormalizeGrad: kernel binary/name, required
# attrs (axis, epsilon), inputs (x, y, dy), output dx, and the supported
# dtype/format rows (one DataType per input/output, in declaration order).
l2_normalize_grad_op_info = TBERegOp("L2NormalizeGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("l2_normalize_grad.so") \
    .compute_cost(10) \
    .kernel_name("l2_normalize_grad") \
    .partial_flag(True) \
    .attr("axis", "required", "listInt", "all") \
    .attr("epsilon", "required", "float", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "y", False, "required", "all") \
    .input(2, "dy", False, "required", "all") \
    .output(0, "dx", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(l2_normalize_grad_op_info)
def _l2_normalize_grad_tbe():
    """Register the L2NormalizeGrad TBE operator info defined above."""
    return None
Esempio n. 2
0
# TBE operator registration for BNTrainingUpdate: seven inputs (x, the
# per-batch sum/square_sum statistics, scale/offset, running mean/variance)
# and five outputs. x and y carry reshape_type="NC"; each dtype_format row
# lists 12 DataTypes (7 inputs then 5 outputs, in declaration order).
bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("bn_training_update.so") \
    .compute_cost(10) \
    .kernel_name("bn_training_update") \
    .partial_flag(True) \
    .attr("factor", "optional", "float", "all") \
    .attr("epsilon", "optional", "float", "all") \
    .attr("isRef", "optional", "bool", "all", "true") \
    .input(0, "x", False, "required", "all", reshape_type="NC") \
    .input(1, "sum", False, "required", "all") \
    .input(2, "square_sum", False, "required", "all") \
    .input(3, "scale", False, "required", "all") \
    .input(4, "offset", False, "required", "all") \
    .input(5, "mean", False, "required", "all") \
    .input(6, "variance", False, "required", "all") \
    .output(0, "y", False, "required", "all", reshape_type="NC") \
    .output(1, "mean", False, "required", "all") \
    .output(2, "variance", False, "required", "all") \
    .output(3, "batch_mean", False, "required", "all") \
    .output(4, "batch_variance", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()
Esempio n. 3
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""NMSWithMask op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for NMSWithMask: one input (box_scores) and three
# outputs (selected_boxes, selected_idx, selected_mask); each dtype_format row
# lists 4 DataTypes (input then the three outputs, in declaration order).
# Fix: the three outputs all declared index 0; output indices must be
# sequential (0, 1, 2) so each name maps to a distinct output slot.
nms_with_mask_op_info = TBERegOp("NMSWithMask") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("nms_with_mask.so") \
    .compute_cost(10) \
    .kernel_name("nms_with_mask") \
    .partial_flag(True) \
    .attr("iou_threshold", "optional", "float", "all") \
    .input(0, "box_scores", False, "required", "all") \
    .output(0, "selected_boxes", False, "required", "all") \
    .output(1, "selected_idx", False, "required", "all") \
    .output(2, "selected_mask", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(nms_with_mask_op_info)
def _nms_with_mask_tbe():
    """Register the NMSWithMask TBE operator info defined above."""
    return None
Esempio n. 4
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Add op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for Reciprocal: single input x, single output y,
# dynamicFormat op pattern, fp16/fp32 in Default, NC1HWC0 (5HD) and NHWC.
reciprocal_op_info = TBERegOp("Reciprocal") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reciprocal.so") \
    .compute_cost(10) \
    .kernel_name("reciprocal") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("dynamicFormat") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
    .get_op_info()


@op_info_register(reciprocal_op_info)
def _reciprocal_tbe():
    """Reciprocal TBE register"""
    return
Esempio n. 5
0
from te import tvm
from topi import generic
from topi.cce import util

# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)  # NoneType is not a builtin name; bind it for isinstance checks

# TBE operator registration for the custom CusMatMulCube op: required
# transpose flags, two required matrix inputs plus an optional x3, and a
# single FRACTAL_NZ fp16-in / fp32-out dtype_format row.
matmul_cube_op_info = TBERegOp("CusMatMulCube") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("matmulcube.so") \
    .compute_cost(10) \
    .kernel_name("CusMatMulCube") \
    .partial_flag(True) \
    .attr("transpose_a", "required", "bool", "all") \
    .attr("transpose_b", "required", "bool", "all") \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .input(2, "x3", False, "optional", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.F32_FracNZ) \
    .get_op_info()


# pylint: disable=locally-disabled,too-many-arguments,too-many-branches, too-many-statements, too-many-locals,
def _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b):
    """
    Check the given input if legal

    Parameters:
Esempio n. 6
0
# TBE operator registration for MatMul: required transpose flags, optional
# quantization offset, inputs (x1, x2, optional bias and offset_w) and output
# y. Each dtype_format row lists 5 DataTypes (4 inputs then the output, in
# declaration order); offset_w is always I8_Default.
matmul_op_info = TBERegOp("MatMul") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("mat_mul.so") \
    .compute_cost(10) \
    .kernel_name("mat_mul") \
    .partial_flag(True) \
    .need_check_supported(True) \
    .attr("transpose_x1", "required", "bool", "all") \
    .attr("transpose_x2", "required", "bool", "all") \
    .attr("offset_x", "optional", "int", "all", "0") \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .input(2, "bias", False, "optional", "all") \
    .input(3, "offset_w", False, "optional", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default,
                  DataType.I32_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
                  DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
                  DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I8_Default,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I8_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, DataType.I8_Default,
                  DataType.I32_NHWC) \
    .get_op_info()
Esempio n. 7
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UnsortedSegmentSum op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for UnsortedSegmentSum: required num_segments
# attr, inputs x and segment_ids, output y. Format selection is delegated to
# the kernel (is_dynamic_format), hence the single None_None dtype row.
unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("unsorted_segment_sum_d.so") \
    .compute_cost(10) \
    .kernel_name("unsorted_segment_sum_d") \
    .partial_flag(True) \
    .attr("num_segments", "required", "int", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "segment_ids", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .is_dynamic_format(True) \
    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
    .get_op_info()


@op_info_register(unsorted_segment_sum_op_info)
def _unsorted_segment_sum_tbe():
    """Register the UnsortedSegmentSum TBE operator info defined above."""
    return None
Esempio n. 8
0
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for FakeQuantPerLayer: optional quantization
# attrs (symmetric, narrow_range, num_bits), inputs x/min/max and output y;
# dtype_format rows cover fp16/fp32 in Default and NC1HWC0 (5HD) layouts.
fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("fake_quant_per_layer.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_per_layer") \
    .partial_flag(True) \
    .attr("symmetric", "optional", "bool", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .attr("num_bits", "optional", "int", "all") \
    .input(0, "x", None, "required", None) \
    .input(1, "min", None, "required", None) \
    .input(2, "max", None, "required", None) \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(fake_quant_per_layer_op_info)
def _fake_quant_per_layer_tbe():
    """FakeQuantPerLayer TBE register.

    Fix: the decorator previously referenced the undefined name
    ``fake_quant_op_info`` (a NameError at import time); the op info defined
    above is ``fake_quant_per_layer_op_info``.
    """
    return
Esempio n. 9
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AtanGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for AtanGrad: elementwise grad with inputs y and
# dy and output z; dtype_format rows cover 5HD, FracZ, C1HWNCoC0 and Default
# layouts for fp16 and fp32.
# Fix: the FracZ rows listed F16_FracNZ / F32_FracNZ for the dy input while
# y and z were FracZ — an elementwise op's operands must share one format,
# matching every other row in this table.
atan_grad_op_info = TBERegOp("AtanGrad") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("atan_grad.so") \
    .compute_cost(10) \
    .kernel_name("atan_grad") \
    .partial_flag(True) \
    .input(0, "y", False, "required", "all") \
    .input(1, "dy", False, "required", "all") \
    .output(0, "z", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(atan_grad_op_info)
def _atan_grad_tbe():
    """Register the AtanGrad TBE operator info defined above."""
    return None
Esempio n. 10
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BiasAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for BiasAdd (dynamicFormat pattern; int32/fp16/
# fp32 with kernel-selected formats via the *_None dtype rows).
# NOTE(review): the variable is named ``bias_add_grad_op_info`` although it
# registers the forward "BiasAdd" op, not BiasAddGrad. Renaming would change
# the module-level interface, so it is only flagged here — confirm intent.
bias_add_grad_op_info = TBERegOp("BiasAdd") \
    .fusion_type("COMMREDUCE") \
    .async_flag(False) \
    .binfile_name("bias_add.so") \
    .compute_cost(10) \
    .kernel_name("bias_add") \
    .partial_flag(True) \
    .attr("data_format", "required", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "bias", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("dynamicFormat") \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(bias_add_grad_op_info)
def _bias_add_tbe():
    """Register the BiasAdd TBE operator info defined above."""
    return None
Esempio n. 11
0
# TBE operator registration for ApplyFtrl: eight inputs (var/accum/linear/
# grad tensors plus the scalar hyper-parameters lr, l1, l2, lr_power) and the
# updated var as output. Each dtype_format row lists 9 DataTypes (8 inputs
# then the output); the scalar hyper-parameters stay in Default format while
# the tensor operands vary across 5HD, FracZ, C1HWNCoC0 and Default.
apply_ftrl_op_info = TBERegOp("ApplyFtrl") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_ftrl.so") \
    .compute_cost(10) \
    .kernel_name("apply_ftrl") \
    .partial_flag(True) \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "linear", False, "required", "all") \
    .input(3, "grad", False, "required", "all") \
    .input(4, "lr", False, "required", "all") \
    .input(5, "l1", False, "required", "all") \
    .input(6, "l2", False, "required", "all") \
    .input(7, "lr_power", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
                  DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                  DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default) \
    .get_op_info()
Esempio n. 12
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SoftplusGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for SoftplusGrad: broadcast op pattern with
# inputs gradients/features and output backprops; fp16/fp32 with
# kernel-selected formats (the *_None dtype rows).
softplus_grad_op_info = TBERegOp("SoftplusGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("softplus_grad.so") \
    .compute_cost(10) \
    .kernel_name("softplus_grad") \
    .partial_flag(True) \
    .op_pattern("broadcast") \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "features", False, "required", "all") \
    .output(0, "backprops", False, "required", "all") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(softplus_grad_op_info)
def _softplus_grad_tbe():
    """Register the SoftplusGrad TBE operator info defined above."""
    return None
Esempio n. 13
0
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for BiasAddGrad: reduces output_backprop to a
# bias gradient. Each dtype_format row pairs one input layout (Default,
# FracNZ, NDC1HWC0 or FRACTAL_Z_3D) with one output layout (Default or NHWC)
# for fp16 and fp32.
bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
    .fusion_type("COMMREDUCE") \
    .async_flag(False) \
    .binfile_name("bias_add_grad.so") \
    .compute_cost(10) \
    .kernel_name("bias_add_grad") \
    .partial_flag(True) \
    .attr("format", "required", "str", "all") \
    .input(0, "output_backprop", False, "required", "all") \
    .output(0, "output", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \
    .dtype_format(DataType.F16_Default, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_Default) \
    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_Default) \
    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \
    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NHWC) \
    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \
    .get_op_info()

Esempio n. 14
0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Invert op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for Invert (bitwise NOT): formatAgnostic pattern,
# single input x and output y, int16/uint16 with kernel-selected formats.
invert_op_info = TBERegOp("Invert") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("invert.so") \
    .compute_cost(10) \
    .kernel_name("invert") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.I16_None, DataType.I16_None) \
    .dtype_format(DataType.U16_None, DataType.U16_None) \
    .get_op_info()


@op_info_register(invert_op_info)
def _invert_tbe():
    """Register the Invert TBE operator info defined above."""
    return None
Esempio n. 15
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReluGradV2 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for ReluGradV2: inputs gradients (5HD) and mask
# (always U8_Default), output backprops in the gradients' dtype/format.
relu_grad_v2_op_info = TBERegOp("ReluGradV2") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("relu_grad_v2.so") \
    .compute_cost(10) \
    .kernel_name("relu_grad_v2") \
    .partial_flag(True) \
    .input(0, "gradients", False, "required", "all") \
    .input(1, "mask", False, "required", "all") \
    .output(0, "backprops", True, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \
    .dtype_format(DataType.I32_5HD, DataType.U8_Default, DataType.I32_5HD) \
    .dtype_format(DataType.I8_5HD, DataType.U8_Default, DataType.I8_5HD) \
    .dtype_format(DataType.U8_5HD, DataType.U8_Default, DataType.U8_5HD) \
    .get_op_info()


@op_info_register(relu_grad_v2_op_info)
def _relu_grad_v2_tbe():
    """Register the ReluGradV2 TBE operator info defined above."""
    return None
Esempio n. 16
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMean op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for ReduceMean: optional axis/keep_dims attrs,
# single input x and output y, int8/uint8/fp16/fp32 in Default format.
reduce_mean_op_info = TBERegOp("ReduceMean") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("reduce_mean.so") \
    .compute_cost(10) \
    .kernel_name("reduce_mean") \
    .partial_flag(True) \
    .attr("axis", "optional", "listInt", "all") \
    .attr("keep_dims", "optional", "bool", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(reduce_mean_op_info)
def _reduce_mean_tbe():
    """Register the ReduceMean TBE operator info defined above."""
    return None
Esempio n. 17
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InplaceUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for InplaceUpdate: required indices attr, inputs
# x and v, output y; fp16/fp32/int32 in Default format.
inplace_update_op_info = TBERegOp("InplaceUpdate") \
    .fusion_type("INPLACE") \
    .async_flag(False) \
    .binfile_name("inplace_update_d.so") \
    .compute_cost(10) \
    .kernel_name("inplace_update_d") \
    .partial_flag(True) \
    .need_check_supported(True) \
    .attr("indices", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "v", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(inplace_update_op_info)
def _inplace_update_tbe():
    """Register the InplaceUpdate TBE operator info defined above."""
    return None
Esempio n. 18
0
# TBE operator registration for TransData (layout conversion): src_format and
# dst_format attrs enumerate the accepted format names; each dtype_format row
# is one supported (source layout -> destination layout) pair, grouped by
# dtype (uint16, bool, fp16, fp32).
trans_data_op_info = TBERegOp("TransData") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("trans_data.so") \
    .compute_cost(10) \
    .kernel_name("trans_data") \
    .partial_flag(True) \
    .attr("src_format", "required", "str", "DefaultFormat,NC1HWC0,FracZ,FRACTAL_NZ,HWCN,C1HWNCoC0")\
    .attr("dst_format", "required", "str", "DefaultFormat,NC1HWC0,FracZ,FRACTAL_NZ,HWCN,C1HWNCoC0")\
    .input(0, "src", False, "required", "all") \
    .output(0, "dst", False, "required", "all") \
    .dtype_format(DataType.U16_Default, DataType.U16_5HD) \
    .dtype_format(DataType.U16_Default, DataType.U16_FracZ) \
    .dtype_format(DataType.U16_Default, DataType.U16_FracNZ) \
    .dtype_format(DataType.U16_FracZ, DataType.U16_Default) \
    .dtype_format(DataType.U16_FracZ, DataType.U16_HWCN) \
    .dtype_format(DataType.U16_FracNZ, DataType.U16_Default) \
    .dtype_format(DataType.U16_5HD, DataType.U16_Default) \
    .dtype_format(DataType.U16_HWCN, DataType.U16_FracZ) \
    .dtype_format(DataType.U16_HWCN, DataType.U16_C1HWNCoC0) \
    .dtype_format(DataType.U16_C1HWNCoC0, DataType.U16_HWCN) \
    .dtype_format(DataType.BOOL_Default, DataType.BOOL_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_5HD) \
    .dtype_format(DataType.F16_Default, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_Default, DataType.F16_FracNZ) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_HWCN) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_Default) \
    .dtype_format(DataType.F16_HWCN, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_HWCN, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_HWCN) \
    .dtype_format(DataType.F32_Default, DataType.F32_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_Default, DataType.F32_FracNZ) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_HWCN) \
    .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_Default) \
    .dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \
    .dtype_format(DataType.F32_Default, DataType.F32_NCHW) \
    .dtype_format(DataType.F32_HWCN, DataType.F32_Default) \
    .get_op_info()
Esempio n. 19
0
# ============================================================================
"""Conv3DTranspose op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for Conv3DTranspose: required geometry attrs
# (input_size, strides, pad_list), optional dilations/groups/format/
# output_padding/offset_x, inputs x/filter plus optional bias and offset_w,
# output y; single fp16 NDC1HWC0 / FRACTAL_Z_3D dtype_format row.
# NOTE(review): fusion type is spelled "CONVLUTION" — possibly the spelling
# the framework's fusion-type enum actually expects; confirm before "fixing".
conv3d_transpose_op_info = TBERegOp("Conv3DTranspose") \
    .fusion_type("CONVLUTION") \
    .async_flag(False) \
    .binfile_name("conv3d_transpose_d.so") \
    .compute_cost(10) \
    .kernel_name("conv3d_transpose_d") \
    .partial_flag(True) \
    .attr("input_size", "required", "listInt", "all") \
    .attr("strides", "required", "listInt", "all") \
    .attr("pad_list", "required", "listInt", "all") \
    .attr("dilations", "optional", "listInt", "all") \
    .attr("groups", "optional", "int", "all") \
    .attr("format", "optional", "str", "all") \
    .attr("output_padding", "optional", "listInt", "all") \
    .attr("offset_x", "optional", "int", "all", "0") \
    .input(0, "x", False, "required", "all") \
    .input(1, "filter", False, "required", "all") \
    .input(2, "bias", False, "optional", "all") \
    .input(3, "offset_w", False, "optional", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.I8_Default,
                  DataType.F16_NDC1HWC0) \
    .get_op_info()


@op_info_register(conv3d_transpose_op_info)
def _conv3d_transpose_tbe():
Esempio n. 20
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pow op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for Pow: broadcast pattern over inputs x1/x2 with
# output y; int8/uint8/int32/fp16/fp32 with kernel-selected formats.
pow_op_info = TBERegOp("Pow") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("pow.so") \
    .compute_cost(10) \
    .kernel_name("pow") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(pow_op_info)
def _pow_tbe():
    """Register the Pow TBE operator info defined above."""
    return None
Esempio n. 21
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for Pad (compile-time paddings variant pad_d):
# optional paddings attr, input x, output y; int8/uint8/int32/fp16/fp32 in
# Default format.
pad_d_op_info = TBERegOp("Pad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("pad_d.so") \
    .compute_cost(10) \
    .kernel_name("pad_d") \
    .partial_flag(True) \
    .attr("paddings", "optional", "listListInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(pad_d_op_info)
def _pad_d_tbe():
    """Register the Pad TBE operator info defined above."""
    return None
Esempio n. 22
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Add op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# TBE operator registration for Split: required axis/output_num attrs, input
# value and a dynamic number of outputs; dynamicFormat pattern with a single
# None_None dtype row (dtype/format selected by the kernel).
split_d_op_info = TBERegOp("Split") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("split_d.so") \
    .compute_cost(10) \
    .kernel_name("split_d") \
    .partial_flag(True) \
    .attr("axis", "required", "int", "all") \
    .attr("output_num", "required", "int", "all") \
    .input(0, "value", False, "required", "all") \
    .output(0, "output", False, "dynamic", "all") \
    .op_pattern("dynamicFormat") \
    .dtype_format(DataType.None_None, DataType.None_None) \
    .get_op_info()


@op_info_register(split_d_op_info)
def _split_d_tbe():
    """Split TBE register"""
    return
Esempio n. 23
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterNdUpdate op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Registration info for the ScatterNdUpdate TBE kernel. In-place op: the
# output "var" aliases input "var"; "indices" selects the slices of var that
# are overwritten with "updates".
# Fix: the "updates" input was registered with duplicate index 1; it is the
# third input and must use index 2 (cf. L2NormalizeGrad's 0/1/2 inputs above).
scatter_nd_update_op_info = TBERegOp("ScatterNdUpdate") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("scatter_nd_update.so") \
    .compute_cost(10) \
    .kernel_name("scatter_nd_update") \
    .partial_flag(True) \
    .attr("use_locking", "optional", "bool", "all") \
    .input(0, "var", False, "required", "all") \
    .input(1, "indices", False, "required", "all") \
    .input(2, "updates", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(scatter_nd_update_op_info)
def _scatter_nd_update_tbe():
    """Register the ScatterNdUpdate operator info with the TBE backend."""
    return None
Esempio n. 24
0
# Registration info for the fused Mul + ApplyMomentum TBE kernel.
# Six inputs (var, accum, lr, x1, momentum, x2); the single output "var"
# shares its name with input 0, indicating an in-place update.
fused_mul_apply_momentum_op_info = (
    TBERegOp("FusedMulApplyMomentum")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("fused_mul_apply_momentum.so")
    .compute_cost(10)
    .kernel_name("fused_mul_apply_momentum")
    .partial_flag(True)
    .attr("use_nesterov", "optional", "bool", "true,false", "false")
    .input(0, "var", False, "required", "all")
    .input(1, "accum", False, "required", "all")
    .input(2, "lr", False, "required", "all")
    .input(3, "x1", False, "required", "all")
    .input(4, "momentum", False, "required", "all")
    .input(5, "x2", False, "required", "all")
    .output(0, "var", False, "required", "all")
    # Tensor args (var/accum/x1/x2/out) track one layout per row; the scalar
    # args (lr, momentum) always stay in Default format.
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD,
                  DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD)
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_C1HWNCoC0,
                  DataType.F16_Default, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0)
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default)
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                  DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ)
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                  DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD)
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_C1HWNCoC0,
                  DataType.F32_Default, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0)
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                  DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ)
    .get_op_info()
)
Esempio n. 25
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResizeNearestNeighbor op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Registration info for the ResizeNearestNeighbor TBE kernel; maps the
# framework op onto the "resize_nearest_neighbor_v2_d" binary kernel.
resize_nearest_neighbor_op_info = (
    TBERegOp("ResizeNearestNeighbor")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("resize_nearest_neighbor_d.so")
    .compute_cost(10)
    .kernel_name("resize_nearest_neighbor_v2_d")
    .partial_flag(True)
    .attr("size", "required", "listInt", "all")
    .attr("align_corners", "optional", "bool", "all")
    .input(0, "images", False, "required", "all")
    .output(0, "y", True, "required", "all")
    # Only 5HD layouts are registered, in fp16 and fp32.
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD)
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD)
    .get_op_info()
)


@op_info_register(resize_nearest_neighbor_op_info)
def _resize_nearest_neighbor_tbe():
    """Register the ResizeNearestNeighbor operator info with the TBE backend."""
    return None
Esempio n. 26
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Neg op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Registration info for the elementwise Neg TBE kernel. The formatAgnostic
# pattern lets any layout through, so dtypes are paired with the None format.
neg_op_info = (
    TBERegOp("Neg")
    .fusion_type("ELEMWISE")
    .async_flag(False)
    .binfile_name("neg.so")
    .compute_cost(10)
    .kernel_name("neg")
    .partial_flag(True)
    .input(0, "x", False, "required", "all")
    .output(0, "y", False, "required", "all")
    .op_pattern("formatAgnostic")
    .dtype_format(DataType.I32_None, DataType.I32_None)
    .dtype_format(DataType.F16_None, DataType.F16_None)
    .dtype_format(DataType.F32_None, DataType.F32_None)
    .dtype_format(DataType.I8_None, DataType.I8_None)
    .get_op_info()
)


@op_info_register(neg_op_info)
def _neg_tbe():
    """Register the Neg operator info with the TBE backend."""
    return None
Esempio n. 27
0
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# 2**31 — upper bound on tensor element count. NOTE(review): not referenced in
# the visible snippet; presumably consumed by the kernel implementation.
SHAPE_SIZE_LIMIT = 2147483648

# Registration info for the CorrectionMulGrad TBE kernel (batch-norm folding
# gradient): four fp32 5HD inputs, two fp32 5HD outputs.
# NOTE(review): the inputs pass None where sibling registrations pass
# False/"all"; kept byte-identical here since changing them could alter the
# generated registration — confirm against TBERegOp.input defaults.
correction_mul_grad_op_info = (
    TBERegOp("CorrectionMulGrad")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("correction_mul_grad.so")
    .compute_cost(10)
    .kernel_name("correction_mul_grad")
    .partial_flag(True)
    .attr("channel_axis", "optional", "int", "all")
    .input(0, "dout", None, "required", None)
    .input(1, "x", None, "required", None)
    .input(2, "batch_std", None, "required", None)
    .input(3, "running_std", None, "required", None)
    .output(0, "dx", True, "required", "all")
    .output(1, "mul_dx", True, "required", "all")
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD)
    .get_op_info()
)


@op_info_register(correction_mul_grad_op_info)
def _correction_mul_grad_tbe():
    """Register the CorrectionMulGrad operator info with the TBE backend."""
    return None