Example #1
0
def log_softmax(g, input, dim, dtype=None):
    """Emit an ONNX ``LogSoftmax`` over axis ``dim``, optionally followed by a ``Cast``.

    Args:
        g: graph context; nodes are created through ``g.op``.
        input: value to normalize.
        dim: axis passed to ``LogSoftmax`` as the ``axis`` attribute.
        dtype: optional value whose producing node decides whether a cast is
            emitted. If it is falsy, or produced by a ``prim::Constant`` node,
            no cast is added (presumably the ``None`` dtype constant — confirm
            against callers). Otherwise it is read as an integer scalar-type
            code via ``symbolic_helper._get_const``.

    Returns:
        The ``LogSoftmax`` output, or a ``Cast`` of it to the requested type.
    """
    result = g.op("LogSoftmax", input, axis_i=dim)
    if not dtype or dtype.node().kind() == "prim::Constant":
        return result
    target_code = symbolic_helper._get_const(dtype, "i", "dtype")
    target_onnx_type = _type_utils.JitScalarType(target_code).onnx_type()
    return g.op("Cast", result, to_i=target_onnx_type)
Example #2
0
def quantize_per_tensor(g, input, scale, zero_point, dtype):
    """Lower a per-tensor quantize through ``symbolic_helper.quantize_helper``.

    ``dtype`` is resolved to a constant integer scalar-type code. ``zero_point``
    is cast to the matching ONNX type, ``scale`` is cast to ONNX ``FLOAT``, and
    both are passed to the helper together with ``g`` and ``input``.

    Returns:
        Whatever ``symbolic_helper.quantize_helper`` returns.
    """
    dtype_code = symbolic_helper._get_const(dtype, "i", "dtype")
    zero_point_onnx_type = _type_utils.JitScalarType(dtype_code).onnx_type()
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    cast_zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    cast_scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, cast_scale, cast_zero_point)
Example #3
0
 def reduce_dim(g, self, dim, keepdim, dtype):
     """Optionally cast ``self`` to a constant ``dtype``, then apply ``symbolic``.

     The kind of node that produced ``dtype`` selects the behavior:

     * ``onnx::Constant``: read as an integer scalar-type code, and ``self`` is
       cast to the matching ONNX type before the reduction.
     * ``prim::Constant``: no cast is emitted (presumably the ``None`` dtype;
       confirm against callers).
     * anything else: unsupported; the result of
       ``symbolic_helper._unimplemented`` is returned instead.

     NOTE(review): ``name`` and ``symbolic`` are free variables, presumably
     bound by an enclosing factory that is not visible here.
     """
     producer_kind = dtype.node().kind()
     if producer_kind not in ("onnx::Constant", "prim::Constant"):
         return symbolic_helper._unimplemented(name, "dtype", dtype)
     operand = self
     if producer_kind == "onnx::Constant":
         dtype_code = symbolic_helper._get_const(dtype, "i", "dtype")
         onnx_type = _type_utils.JitScalarType(dtype_code).onnx_type()
         operand = g.op("Cast", self, to_i=onnx_type)
     return symbolic(g, operand, dim, keepdim)
Example #4
0
def _constant_fill(g, sizes, dtype: int, const_value):
    """Build a ``ConstantFill`` node over ``sizes`` holding ``const_value``.

    Args:
        g: graph context; nodes are created through ``g.op``.
        sizes: shape input, passed to ``ConstantFill`` with ``input_as_shape_i=1``.
        dtype: scalar-type code accepted by ``_type_utils.JitScalarType``, or
            ``None`` for FLOAT. NOTE(review): the ``int`` annotation does not
            admit ``None``, although ``None`` is handled here.
        const_value: fill value, emitted as the ``value_f`` float attribute.

    Returns:
        For floating-point targets, the ``ConstantFill`` output built directly
        in the target type. Otherwise, a ``Cast`` to the target type of a
        ``ConstantFill`` built as FLOAT.
    """
    target = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    is_float_target = target.dtype().is_floating_point
    # Non-floating targets are filled as FLOAT and cast afterwards (presumably
    # because the fill value travels as the float attribute ``value_f``).
    fill_type = target if is_float_target else _type_utils.JitScalarType.FLOAT
    filled = g.op(
        "ConstantFill",
        sizes,
        dtype_i=fill_type.onnx_type(),
        input_as_shape_i=1,
        value_f=const_value,
    )
    if is_float_target:
        return filled
    return g.op("Cast", filled, to_i=target.onnx_type())