def my_custom_lowering(inputs, out_shape, out_type, device):
    """Build a TensorExpr computation that replaces NaN elements with 0.0.

    The result is produced by ``te.Compute2`` under the name
    ``"custom_nan_to_num"``. The element callback loads from the buffer of
    ``inputs[0]`` at the indices it is given. It yields ``0.0`` where that
    value is NaN and the value itself otherwise. Only NaN is rewritten;
    every other value, including infinities, passes through unchanged.

    NOTE(review): the four-argument signature looks like the NNC custom
    lowering callback interface. ``out_type`` and ``device`` are accepted
    but never used here; confirm against the registration site.

    Args:
        inputs: lowering arguments; ``inputs[0]`` must expose ``as_buf()``.
        out_shape: output dimensions, forwarded verbatim to ``te.Compute2``.
        out_type: unused.
        device: unused.

    Returns:
        Whatever ``te.Compute2`` returns for the element-wise body.
    """

    def _element(indices):
        # Fetch the source element addressed by the callback's indices.
        value = inputs[0].as_buf().load(indices)
        nan_mask = te.ExprHandle.isnan(value)
        # Select 0.0 where the element is NaN, otherwise keep it as-is.
        return te.ifThenElse(nan_mask, te.ExprHandle.float(0.0), value)

    return te.Compute2("custom_nan_to_num", out_shape, _element)
Beispiel #2
0
        def my_custom_lowering(inputs, out_shape, out_type, device):
            """Build a TensorExpr computation that replaces NaN elements with 0.0.

            Variant that passes a list of ``te.DimArg`` objects to
            ``te.Compute2`` rather than the raw ``out_shape``. This is
            presumably for an older ``Compute2`` overload; confirm against
            the bindings in use. Each entry of ``out_shape`` is wrapped in a
            DimArg named ``i0``, ``i1``, ... in order.

            The element callback loads from the buffer of ``inputs[0]`` at
            the indices it is given. It yields ``0.0`` where that value is
            NaN and the value itself otherwise. Infinities are not
            rewritten.

            NOTE(review): ``out_type`` and ``device`` are accepted but never
            used; the signature appears to be dictated by the NNC custom
            lowering callback interface.

            Args:
                inputs: lowering arguments; ``inputs[0]`` must expose
                    ``as_buf()``.
                out_shape: iterable of output dimensions.
                out_type: unused.
                device: unused.

            Returns:
                Whatever ``te.Compute2`` returns for the element-wise body.
            """

            def get_dim_args(dims):
                # Name each dimension by its position: i0, i1, ...
                return [
                    te.DimArg(dim, f"i{index}")
                    for index, dim in enumerate(dims)
                ]

            def compute(idxs):
                load = inputs[0].as_buf().load(idxs)
                # Select 0.0 where the element is NaN, otherwise keep it.
                return te.ifThenElse(te.ExprHandle.isnan(load),
                                     te.ExprHandle.float(0.0), load)

            return te.Compute2("custom_nan_to_num", get_dim_args(out_shape),
                               compute)