def _try_cast_integer_to_float(g, *args):
    """Cast all inputs to Float unless the first input is already floating.

    Only ``args[0]`` is inspected. Its scalar type is trusted only when its
    JIT type kind is ``DimensionedTensorType`` or ``CompleteTensorType``.
    For any other kind the scalar type is treated as unknown.

    Args:
        g: The graph handle, forwarded to ``_cast_Float``.
        *args: Graph values; at least one is required.

    Returns:
        A tuple ``(old_type, *values)``:
        * if ``args[0]`` is known to be non-floating, every argument is cast
          with ``_cast_Float``, and ``old_type`` is the original scalar type
          name so the caller can cast the result back;
        * if ``args[0]`` is already Half/Float/Double, the args are returned
          unchanged with ``old_type = None``;
        * if the scalar type is unknown, a warning is emitted and the args are
          returned unchanged with ``old_type = None``.

    NOTE(review): this file defines ``_try_cast_integer_to_float`` a second
    time further down; that later definition shadows this one at import time.
    """
    first_type = args[0].type()
    if first_type.kind() not in ("DimensionedTensorType", "CompleteTensorType"):
        # Scalar type unknown: we cannot tell whether a cast is needed, so
        # leave the inputs alone and warn that integer inputs may misexport.
        warnings.warn("Only floating datatype is supported for these operators: "
                      "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
                      "the onnx model to be incorrect, if inputs have integer datatypes.")
        return (None,) + args

    source_type = first_type.scalarType()
    if source_type in ('Half', 'Float', 'Double'):
        # Already floating: nothing to cast, nothing to restore later.
        return (None,) + args

    as_float = tuple(_cast_Float(g, value, False) for value in args)
    return (source_type,) + as_float
def _try_cast_integer_to_float(g, *args):
    """Cast all inputs to Float unless the first input is already floating.

    Only the scalar type of ``args[0]`` is consulted. A ``None`` scalar type
    means it is unknown.

    Args:
        g: The graph handle, forwarded to ``opset9._cast_Float``.
        *args: Graph values; at least one is required.

    Returns:
        A tuple ``(old_type, *values)``:
        * if ``args[0]`` has a known non-floating scalar type, every argument
          is cast with ``opset9._cast_Float``, and ``old_type`` is that
          original scalar type so the caller can cast the result back;
        * if ``args[0]`` is already Half/Float/Double, the args are returned
          unchanged with ``old_type = None``;
        * if the scalar type is unknown, a warning is emitted and the args are
          returned unchanged with ``old_type = None``.
    """
    scalar_type = args[0].type().scalarType()

    if scalar_type is None:
        # Unknown dtype: we cannot decide whether a cast is required, so pass
        # the inputs through and warn that integer inputs may misexport.
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect, if inputs have integer datatypes."
        )
        return (None,) + args

    if scalar_type in ("Half", "Float", "Double"):
        # Already floating: nothing to cast, nothing to restore later.
        return (None,) + args

    # TODO(justinchuby): Remove the type ignore hint once _cast_Float is
    # properly defined.
    # NOTE: _cast_Float is generated programmatically so we need to make the
    # type checker happy with ignore[attr-defined].
    float_args = tuple(
        opset9._cast_Float(g, value, False)  # type: ignore[attr-defined]
        for value in args
    )
    return (scalar_type,) + float_args
def nonzero(g, input):
    """Emit an ONNX ``NonZero`` node for ``input`` and pass its result through ``t``.

    The input is first converted with ``_cast_Float`` before it reaches
    ``NonZero``.

    NOTE(review): ``t`` presumably transposes the NonZero output so that each
    row is one index tuple, matching ``torch.nonzero``'s layout. Confirm this
    against ``t``'s definition.
    NOTE(review): if ``_cast_Float`` produces float32 (presumably), a Double
    input holding values below float32's smallest subnormal would be turned
    into 0 and dropped from the result. Confirm whether the cast is still
    required by the target backend.
    """
    as_float = _cast_Float(g, input, False)
    indices = g.op('NonZero', as_float)
    return t(g, indices)