def after(annotate_non_call_ops):
    """Build the annotated module for a ``let``-bound recursive loop.

    The module has the shape::

        fn (var1) {
            let while_loop = fn (var2, var3) {
                if (var2 < 10) { while_loop(var2 + 1, var3 + var1) } else { var3 }
            };
            while_loop(0, zeros_like(var1))
        }

    Every operator call (``less``, both ``add``s and ``zeros_like``) has its
    inputs wrapped in ``compiler_begin(..., target)`` and its result wrapped in
    ``compiler_end(..., target)``. When ``annotate_non_call_ops`` is true, the
    arguments to and results of the ``while_loop`` calls are additionally
    wrapped in begin/end annotations for the ``"default"`` target.

    Note: ``target`` is not a parameter; it is read from an enclosing scope.

    Args:
        annotate_non_call_ops: whether to add the ``"default"`` annotations
            around the ``while_loop`` calls.

    Returns:
        ``tvm.IRModule`` built from the outer function.
    """

    def _default_begin(expr):
        # Enter the "default" region only when non-call ops are annotated.
        if annotate_non_call_ops:
            return relay.annotation.compiler_begin(expr, "default")
        return expr

    def _default_end(expr):
        # Leave the "default" region only when non-call ops are annotated.
        if annotate_non_call_ops:
            return relay.annotation.compiler_end(expr, "default")
        return expr

    var1 = relay.var("var1", shape=(2,))
    var2 = relay.var("var2", shape=(), dtype="int32")
    var3 = relay.var("var3", shape=(2,))
    var4 = relay.const(10, dtype="int32")

    # Loop condition: var2 < 10, offloaded to `target`.
    cb_1 = relay.annotation.compiler_begin(var2, target)
    cb_2 = relay.annotation.compiler_begin(var4, target)
    less_condition = relay.less(cb_1, cb_2)
    ce_1 = relay.annotation.compiler_end(less_condition, target)

    loop = relay.var("while_loop")

    # True branch, first loop argument: var2 + 1.
    cb_3 = relay.annotation.compiler_begin(var2, target)
    cb_4 = relay.annotation.compiler_begin(relay.const(1, dtype="int32"), target)
    add_op_1 = relay.add(cb_3, cb_4)
    ce_2 = relay.annotation.compiler_end(add_op_1, target)
    cb_5 = _default_begin(ce_2)

    # True branch, second loop argument: var3 + var1.
    cb_6 = relay.annotation.compiler_begin(var3, target)
    cb_7 = relay.annotation.compiler_begin(var1, target)
    add_op_2 = relay.add(cb_6, cb_7)
    ce_3 = relay.annotation.compiler_end(add_op_2, target)
    cb_8 = _default_begin(ce_3)

    # Recursive call is the true branch; var3 is returned otherwise.
    true_branch = loop(cb_5, cb_8)
    ce_4 = _default_end(true_branch)
    if_expr = relay.If(ce_1, ce_4, var3)

    # Initial invocation: while_loop(0, zeros_like(var1)).
    cb_9 = _default_begin(relay.const(0, dtype="int32"))
    cb_10 = relay.annotation.compiler_begin(var1, target)
    zeros_like = relay.zeros_like(cb_10)
    ce_5 = relay.annotation.compiler_end(zeros_like, target)
    cb_11 = _default_begin(ce_5)
    initial_call = loop(cb_9, cb_11)
    ce_6 = _default_end(initial_call)

    func_1 = relay.Function([var2, var3], if_expr)
    ret = relay.Let(loop, func_1, ce_6)
    func_2 = relay.Function([var1], ret)
    return tvm.IRModule.from_expr(func_2)
# Example #2
# 0
    def expected():
        """Return the outer function that applies a Primitive-marked inner function.

        The inner function takes a scalar int32 ``p0``, computes
        ``min(p0 < 10)`` and carries the ``Primitive`` attribute; the outer
        function calls it on its own scalar int32 parameter ``x``.
        """
        inner_param = relay.var("p0", shape=(), dtype="int32")
        threshold = relay.const(10, dtype="int32")
        reduced = relay.min(relay.less(inner_param, threshold))
        primitive = relay.Function([inner_param], reduced).with_attr(
            "Primitive", tvm.tir.IntImm("int32", 1)
        )

        outer_param = relay.var("x", shape=(), dtype="int32")
        return relay.Function([outer_param], relay.Call(primitive, [outer_param]))
    def before():
        """Build a module containing a ``let``-bound recursive ``while_loop``.

        ``while_loop(var2, var3)`` returns ``var3`` once ``var2 < 10`` is false
        and otherwise recurses with ``(var2 + 1, var3 + var1)``; the outer
        function starts it at ``(0, zeros_like(var1))``.
        """
        var1 = relay.var("var1", shape=(2,))
        var2 = relay.var("var2", shape=(), dtype="int32")
        var3 = relay.var("var3", shape=(2,))
        keep_going = relay.less(var2, relay.const(10, dtype="int32"))

        loop = relay.var("while_loop")

        # Loop body: recurse with an incremented counter and a grown accumulator.
        recurse = loop(var2 + relay.const(1, dtype="int32"), var3 + var1)
        body = relay.Function([var2, var3], relay.If(keep_going, recurse, var3))

        # Start the loop from a zero counter and a zero-filled accumulator.
        start = loop(relay.const(0, dtype="int32"), relay.zeros_like(var1))
        outer = relay.Function([var1], relay.Let(loop, body, start))
        return tvm.IRModule.from_expr(outer)
# Example #4
# 0
 def before():
     """Return ``fn (x) { min(x < 10) }`` for a scalar int32 parameter ``x``."""
     x = relay.var("x", shape=(), dtype="int32")
     limit = relay.const(10, dtype="int32")
     return relay.Function([x], relay.min(relay.less(x, limit)))
# Example #5
# 0
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from tvm import relay

a = relay.Var("a")
b = relay.expr.const(1.0, dtype='float32')

# Each Python comparison operator applied to Relay expressions must produce the
# same printed text as the corresponding explicit relay.* call. The loop
# variables are named ``c`` and ``d`` so the module-level bindings left behind
# match a straight-line version of this check (the final ``<=`` pair).
for c, d in (
    (a < b, relay.less(a, b)),
    (a > b, relay.greater(a, b)),
    (a >= b, relay.greater_equal(a, b)),
    (a <= b, relay.less_equal(a, b)),
):
    assert c.astext() == d.astext()