def nan_to_num(g, input, nan, posinf, neginf):
    import torch
    from torch.onnx import symbolic_helper as sym_help
    from torch.onnx.symbolic_opset9 import isnan, isinf, lt, gt, logical_and

    # Cannot create an int type tensor with inf/nan values, so we simply
    # return the original tensor.
    if not sym_help._is_fp(input):
        return input
    input_dtype = sym_help.pytorch_name_to_type[input.type().scalarType()]
    if nan is None:
        nan = 0.0
    nan_cond = isnan(g, input)
    nan_result = g.op(
        "Where",
        nan_cond,
        g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
        input,
    )

    # For None values of posinf, neginf we use the greatest/lowest finite
    # value representable by input's dtype.
    finfo = torch.finfo(input_dtype)
    if posinf is None:
        posinf = finfo.max
    posinf_cond = logical_and(
        g,
        isinf(g, nan_result),
        gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
    )
    nan_posinf_result = g.op(
        "Where",
        posinf_cond,
        g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
        nan_result,
    )

    if neginf is None:
        neginf = finfo.min
    neginf_cond = logical_and(
        g,
        isinf(g, nan_posinf_result),
        lt(g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))),
    )
    return g.op(
        "Where",
        neginf_cond,
        g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
        nan_posinf_result,
    )
def nan_to_num(g, input, nan, posinf, neginf):
    import torch
    from torch.onnx import _type_utils, symbolic_helper
    from torch.onnx import symbolic_opset9 as opset9

    # Cannot create an int type tensor with inf/nan values, so we simply
    # return the original tensor.
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_name(
        input.type().scalarType()
    ).dtype()
    if nan is None:
        nan = 0.0
    nan_cond = opset9.isnan(g, input)
    nan_result = g.op(
        "Where",
        nan_cond,
        g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
        input,
    )

    # For None values of posinf, neginf we use the greatest/lowest finite
    # value representable by input's dtype.
    finfo = torch.finfo(input_dtype)
    if posinf is None:
        posinf = finfo.max
    posinf_cond = opset9.logical_and(
        g,
        opset9.isinf(g, nan_result),
        opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
    )
    nan_posinf_result = g.op(
        "Where",
        posinf_cond,
        g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
        nan_result,
    )

    if neginf is None:
        neginf = finfo.min
    neginf_cond = opset9.logical_and(
        g,
        opset9.isinf(g, nan_posinf_result),
        opset9.lt(g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))),
    )
    return g.op(
        "Where",
        neginf_cond,
        g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
        nan_posinf_result,
    )
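# A minimal usage sketch (an illustrative assumption, not part of the original
# source): exporting a module that calls torch.nan_to_num is what routes
# through the symbolic function above. The module name, file name, and opset
# version below are placeholders chosen for the example.
import torch


class _NanToNumModule(torch.nn.Module):
    def forward(self, x):
        # Replace NaN with 0.0 and clamp +/-inf to explicit finite values.
        return torch.nan_to_num(x, nan=0.0, posinf=1e4, neginf=-1e4)


_example_input = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.5])
# The IsInf op used by the symbolic requires ONNX opset >= 10, so pick a
# sufficiently high opset version for the export.
torch.onnx.export(
    _NanToNumModule(), (_example_input,), "nan_to_num.onnx", opset_version=13
)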