Exemplo n.º 1
0
def main():
    """Print a tour of the ``operator`` module's numeric functions.

    Uses fixed operands (a=-1, b=5.0, c=2, d=6) and prints three sections:
    sign/absolute-value functions, arithmetic functions and bitwise
    functions.  Every result is printed as ``label: value``.
    """
    a = -1
    b = 5.0
    c = 2
    d = 6

    for k, v in {"a": a, "b": b, "c": c, "d": d}.items():
        print(k + " =", v)

    print("\nPositive/Negative:")
    print("abs(a):", operator.abs(a))
    print("neg(a):", operator.neg(a))
    print("neg(b):", operator.neg(b))
    print("pos(a):", operator.pos(a))
    print("pos(b):", operator.pos(b))

    print("\nArithmetic:")
    print("add(a, b)     :", operator.add(a, b))
    # floordiv rounds toward negative infinity, so floordiv(-1, 5.0) is -1.0.
    print("floordiv(a, b):", operator.floordiv(a, b))
    print("floordiv(d, c):", operator.floordiv(d, c))
    print("mod(a, b)     :", operator.mod(a, b))
    print("mul(a, b)     :", operator.mul(a, b))
    print("pow(c, d)     :", operator.pow(c, d))
    print("sub(b, a)     :", operator.sub(b, a))
    print("truediv(a, b) :", operator.truediv(a, b))
    print("truediv(d, c) :", operator.truediv(d, c))

    print("\nBitwise:")
    print("and_(c, d)  :", operator.and_(c, d))
    print("invert(c)   :", operator.invert(c))
    print("lshift(c, d):", operator.lshift(c, d))
    print("or_(c, d)   :", operator.or_(c, d))
    print("rshift(d, c):", operator.rshift(d, c))
    print("xor(c, d)   :", operator.xor(c, d))
Exemplo n.º 2
0
    def _expression_to_criterion(self, expr, param_index):
        """Translate an expression tree into a criterion plus bound values.

        If ``expr`` has children, each child is translated recursively and
        the results are joined with ``expr.join_type``.  Otherwise each
        ``(key, value)`` pair in ``expr.filters`` becomes one sub-criterion
        built by ``self.filters.get_criterion``.  The placeholder for a leaf
        is either the function term returned by
        ``FunctionResolve.resolve_value`` or ``self.parameter(param_index)``.

        Args:
            expr: Expression node exposing ``children``, ``filters``,
                ``join_type`` and ``_is_negated``.
            param_index: Index of the next placeholder to allocate.

        Returns:
            A ``(criterion, values)`` tuple; ``values`` lists the bound
            parameter values in the order their placeholders were allocated.
        """
        values = []
        criterion = EmptyCriterion()
        if expr.children:
            for sub_expression in expr.children:
                sub_criterion, sub_values = self._expression_to_criterion(
                    sub_expression, param_index)
                criterion = self._join_criterion(criterion, sub_criterion,
                                                 expr.join_type)
                # Each child consumes as many placeholders as values it bound.
                param_index += len(sub_values)
                values.extend(sub_values)
        else:
            for key, value in expr.filters.items():
                fn, value = FunctionResolve.resolve_value(
                    value, self.pika_table)
                param = fn if fn else self.parameter(param_index)

                sub_criterion, value = self.filters.get_criterion(
                    key, param, value)
                criterion = self._join_criterion(criterion, sub_criterion,
                                                 expr.join_type)

                # Identity test rather than ``!=``: the value could be an
                # object that overloads comparison operators (e.g. a query
                # term), which would make ``value != None`` ambiguous.
                if value is not None:
                    param_index += 1
                    values.append(value)

        if expr._is_negated and not isinstance(criterion, EmptyCriterion):
            criterion = operator.invert(criterion)
        return criterion, values
Exemplo n.º 3
0
def test_enum_flags():
    """Exercise the ``Flags`` class attached to ``EnumTest``.

    Covers construction (from 0, from a member name, from an existing
    flag), the repr/str forms, and the bitwise helpers ``and_``, ``or_``,
    ``xor`` and ``invert``.
    """
    flag_type = EnumTest.Flags
    assert EnumTest.Flags is flag_type

    # Construction and identity
    assert flag_type(0) == 0
    a_flag = flag_type('a')
    assert a_flag == 1
    assert a_flag is flag_type(a_flag)  # wrapping a flag returns it unchanged

    # Textual forms
    assert 'enumflags' in repr(a_flag)
    assert str(a_flag) == 'EnumTestFlags'

    # Bitwise helpers; and_ with a plain int must raise TypeError.
    assert and_(a_flag, EnumTest.a)
    assert or_(flag_type('b'), a_flag)
    assert xor(flag_type('b'), a_flag) == 3
    with pytest.raises(TypeError):
        and_(a_flag, 2)
    assert invert(a_flag) == -2
Exemplo n.º 4
0
def test_enum_flags():
    """Test using the enum flags.

    Checks instance construction and identity, repr/str, and the
    and_/or_/xor/invert helpers on ``EnumTest.Flags``.
    """
    flags_type = EnumTest.Flags
    assert EnumTest.Flags is flags_type
    assert flags_type(0) == 0

    first = flags_type('a')
    assert first == 1
    assert first is flags_type(first)
    assert 'enumflags' in repr(first)
    assert str(first) == 'EnumTestFlags'

    # Logic operations; mixing a flag with a bare int is a TypeError.
    assert and_(first, EnumTest.a)
    assert or_(flags_type('b'), first)
    assert xor(flags_type('b'), first) == 3
    with pytest.raises(TypeError):
        and_(first, 2)
    assert invert(first) == -2
Exemplo n.º 5
0
def run301_03():
    """
    Print a demo of the arithmetic functions from the ``operator`` module.

    Operands are fixed (a=-1, b=5.0, c=2, d=6).  Results are printed in
    three sections: sign/absolute value, arithmetic and bitwise operations.
    :return: None
    """
    a = -1
    b = 5.0
    c = 2
    d = 6

    print('a =', a)
    print('b =', b)
    print('c =', c)
    print('d =', d)

    print('\nPositive/Negative:')
    print('abs(a):', abs(a))
    print('neg(a):', neg(a))
    print('neg(b):', neg(b))
    print('pos(a):', pos(a))
    print('pos(b):', pos(b))

    print('\nArithmetic:')
    print('add(a, b)     :', add(a, b))
    print('floordiv(a, b):', floordiv(a, b))  # a // b, rounds toward -infinity
    print('floordiv(d, c):', floordiv(d, c))  # d // c
    print('mod(a, b)     :', mod(a, b))
    print('mul(a, b)     :', mul(a, b))
    print('pow(c, d)     :', pow(c, d))
    print('sub(b, a)     :', sub(b, a))
    print('truediv(a, b) :', truediv(a, b))
    print('truediv(d, c) :', truediv(d, c))

    print('\nBitwise:')
    print('and_(c, d)  :', and_(c, d))
    print('invert(c)   :', invert(c))  # ~c == -c - 1
    print('lshift(c, d):', lshift(c, d))  # c << d
    print('or_(c, d)   :', or_(c, d))
    print('rshift(d, c):', rshift(d, c))  # d >> c
    print('xor(c, d)   :', xor(c, d))  # c ^ d: bits that differ become 1
Exemplo n.º 6
0
def operator_arithmetic():
    """
    Print the ``operator`` module's arithmetic and bitwise functions.

    Sections (fixed operands a=-1, b=5.0, c=2, d=6):
    sign:       +, -, |a|
    arithmetic: +, -, *, /, //, pow, %
    bitwise:    &, |, ~, ^, <<, >>
    """
    a, b, c, d = -1, 5.0, 2, 6
    print('a =', a)
    print('b =', b)
    print('c =', c)
    print('d =', d)

    print('\nPositive/Negative:')
    print('abs(a):', operator.abs(a))
    print('neg(a):', operator.neg(a))
    print('neg(b):', operator.neg(b))
    print('pos(a):', operator.pos(a))
    print('pos(b):', operator.pos(b))

    print('\nArithmetic:')
    print('add(a, b)        :', operator.add(a, b))
    print('sub(b, a)        :', operator.sub(b, a))
    print('mul(a, b)        :', operator.mul(a, b))
    print('truediv(a, b)    :', operator.truediv(a, b))
    print('truediv(d, c)    :', operator.truediv(d, c))
    print('floordiv(a, b)   :', operator.floordiv(a, b))
    print('pow(c, d)        :', operator.pow(c, d))
    print('mod(a, b)        :', operator.mod(a, b))

    print('\nBitwise:')
    print('and_(c, d)       :', operator.and_(c, d))  # c & d
    print('or_(c, d)        :', operator.or_(c, d))  # c | d
    print('invert(c)        :',
          operator.invert(c))  # two's complement. ~c = -c - 1
    print('xor(c, d)        :', operator.xor(c, d))  # c ^ d
    print('lshift(c, d)     :', operator.lshift(c, d))  # c << d
    print('rshift(d, c)     :', operator.rshift(d, c))  # d >> c
Exemplo n.º 7
0
# Bitwise AND / OR / XOR: the infix operator, the operator-module function
# and the dunder method give the same result.  bin() shows the operands'
# bit patterns.
print(" bit and ", bin(10), bin(2))
print(" &           : ", 10 & 2)
print(" and_(a, b)  : ", op.and_(10,2))
print("x.__and__(y) : ", (10).__and__(2))
print(" bit or ", bin(10), bin(4))
print(" |           : ", 10 | 4)
print(" or_(a, b)   : ", op.or_(10,4))
print("x.__or__(y)  : ", (10).__or__(4))
print(" bit xor ", bin(10), bin(6))
print(" ^            : ", 10 ^ 6)
print(" xor(a, b)    : ", op.xor(10,6))
print("x.__xor__(y)  : ", (10).__xor__(6))

# Bitwise inversion (~x == -x - 1); __invert__ takes no argument.
print(" bit inversion ", bin(10))
print(" ~           : ", ~(10) )
print(" invert(a)   : ", op.invert(10))
print("x.__invert__()  : ", (10).__invert__())

# Unary minus / plus.
print(" negative : ", -(10))
print(" neg(a)   : ", op.neg((10)))
print("a.__neg__   : ", (10).__neg__())
print(" positive : ", -(-10), +(10))
print(" pos(a)   : ", op.pos((10)))
print("a.__pos__   : ", (10).__pos__())

# Reflected ("right hand") operators: y.__radd__(x) computes x + y and
# y.__rpow__(x) computes x ** y.
print(" right hand operator ")
print(" x + y ", (8).__radd__(2))
print(" x + y ", (2).__add__(8))

print(" x ** y ", (3).__rpow__(2))
print(" x ** y ", (2).__pow__(3))
Exemplo n.º 8
0
#!/usr/bin/env python
#-*- coding:utf8 -*-

import operator as op

x = 10

# op.invert(x) is the function form of the unary ~ operator.
# print(...) with a single argument behaves the same on Python 2 and 3.
print(str(op.invert(x)))
print(str(~x))

Exemplo n.º 9
0
def _show(label_template, fn, *operands):
    """Print the label (template filled with the operands), then fn(*operands)."""
    print(label_template.format(*operands), fn(*operands))


# Echo the operands (a, b, c and d are defined earlier in the script).
for _name, _value in (('a', a), ('b', b), ('c', c), ('d', d)):
    print(_name + ' = ', _value)

print('\nPositive/Negative:')
_show('abs({}):', abs, a)
_show('neg({}):', neg, a)
_show('neg({}):', neg, b)
_show('pos({}):', pos, a)
_show('pos({}):', pos, b)

print('\nArithmetic:')
_show('add({}, {}):', add, a, b)
_show('floordiv({}, {}):', floordiv, a, b)
_show('mod({},{}): ', mod, a, b)
_show('mul({},{}): ', mul, a, b)
_show('pow({},{}):', pow, c, d)
_show('sub({},{}):', sub, b, a)
_show('truediv({},{}):', truediv, a, b)
_show('truediv({},{}):', truediv, d, c)

print('\nBitwise:')
_show('and_({}, {}) :', and_, c, d)
_show('invert({}) :', invert, c)
_show('lshift({}, {}) :', lshift, c, d)
_show('or_({}, {}) :', or_, c, d)
_show('rshift({}, {}) :', rshift, d, c)
_show('xor({}, {}) :', xor, c, d)

Exemplo n.º 10
0
 def __invert__(self):
     """Return the bitwise inversion (``~``) of the wrapped ``data`` value."""
     wrapped = self.data
     return ~wrapped
Exemplo n.º 11
0
 def bitwise_not_usecase(x):
     """Return ``operator.invert(x)``, the function form of ``~x``."""
     result = operator.invert(x)
     return result
Exemplo n.º 12
0
def op_invert(x):
    """Return ``_op.invert(x)`` (presumably ``operator.invert``, i.e. ``~x``)."""
    return _op.invert(x)

@cutype("a -> a")
Exemplo n.º 13
0
# Sign functions; a and b are assumed to be defined earlier in the script.
print('operator.abs(a):', op.abs(a))
print('operator.neg(a):', op.neg(a))
print('operator.neg(b):', op.neg(b))
print('operator.pos(a):', op.pos(a))
print('operator.pos(b):', op.pos(b))
print()

# Arithmetic functions on a negative int and a float.
a, b = -2, 5.0
print('a =', a)
print('b =', b)
print('operator.add(a, b):', op.add(a, b))
# operator.div existed only in Python 2; Python 3 has truediv and floordiv.
# print('operator.div(a, b):', op.div(a, b))
print('operator.floordiv(a, b):', op.floordiv(a, b))
print('operator.mod(a, b):', op.mod(a, b))
print('operator.mul(a, b):', op.mul(a, b))
print('operator.pow(a, b):', op.pow(a, b))
print('operator.sub(a, b):', op.sub(a, b))
print('operator.truediv(a, b):', op.truediv(a, b))
print()

# Bitwise functions on two small ints.
a, b = 2, 6
print('a =', a)
print('b =', b)
print('operator.and_(a, b):', op.and_(a, b))
# invert() is unary, so its label shows a single argument.
print('operator.invert(a):', op.invert(a))
print('operator.lshift(a, b):', op.lshift(a, b))
print('operator.or_(a, b):', op.or_(a, b))
print('operator.rshift(a, b):', op.rshift(a, b))
print('operator.xor(a, b):', op.xor(a, b))
print()
Exemplo n.º 14
0
 def bitwise_not_usecase_binary(x, _unused):
     """Two-argument variant: ignore the second argument and return ``~x``."""
     inverted = operator.invert(x)
     return inverted
Exemplo n.º 15
0
def op_invert(x):
    """Bitwise inversion through the ``_op`` module alias (same as ``~x``)."""
    inverted = _op.invert(x)
    return inverted
Exemplo n.º 16
0
 def execute_TILDE(self):
     """Pop the top stack entry, coerce it with int() and push its bitwise NOT."""
     stack = self.stack
     stack.push(operator.invert(int(stack.pop())))
Exemplo n.º 17
0
 def __invert__(self):
     """Support ``~obj`` by inverting the wrapped ``data`` attribute."""
     inverted = operator.invert(self.data)
     return inverted
Exemplo n.º 18
0
print("neg(b):", operator.neg(b))
print("pos(a):", operator.pos(a))
print("pos(b):", operator.pos(b))
print("\nArithmetic:")
print("add(a, b) :", operator.add(a, b))
print("floordiv(a, b):", operator.floordiv(a, b))
print("floordiv(d, c):", operator.floordiv(d, c))
print("mod(a, b) :", operator.mod(a, b))
print("mul(a, b) :", operator.mul(a, b))
print("pow(c, d) :", operator.pow(c, d))
print("sub(b, a) :", operator.sub(b, a))
print("truediv(a, b) :", operator.truediv(a, b))
print("truediv(d, c) :", operator.truediv(d, c))
print("\nBitwise:")
print("and_(c, d) :", operator.and_(c, d))
print("invert(c) :", operator.invert(c))
print("lshift(c, d):", operator.lshift(c, d))
print("or_(c, d) :", operator.or_(c, d))
print("rshift(d, c):", operator.rshift(d, c))

# ------------------------------------------------------------------------------
# Sequence Operators
# The operators for working with sequences can be organized into four groups:
# building up sequences, searching for items, accessing contents, and removing
# items from sequences.

a = [1, 2, 3]
b = ["a", "b", "c"]
print("a =", a)
print("b =", b)
Exemplo n.º 19
0
 def update_event(self, inp=-1):
     """Write the bitwise inversion of input 0 to output 0.

     ``inp`` is part of the callback signature but is not used here.
     """
     source = self.input(0)
     self.set_output_val(0, operator.invert(source))
 def execute_TILDE(self):
     # Apply ~ to the int value of the popped stack entry and push the result.
     top = int(self.stack.pop())
     self.stack.push(operator.invert(top))
Exemplo n.º 21
0
l1 = [1, 2, 3, 4]  ## indexing starts from 0
## getitem
print(op.getitem(l1, 3))
print(op.getitem(l1, slice(1, 4)))  ## slice(a, b) selects the half-open range [a, b)

## setitem
op.setitem(l1, 2, 6)
op.setitem(l1, slice(1, 2), [4])

## delitem
op.delitem(l1, 1)
op.delitem(l1, slice(2, 3))

## traversing the list
print("list")
for x in l1:
    print(x)

l2 = [90, 23, 1]
l3 = op.concat(l1, l2)
## contains(l3, l1) is ``l1 in l3``: it asks whether the list l1 is itself an
## element of l3, not whether every item of l1 occurs in l3 -- so it is False.
print(op.contains(l3, l1))

######### bitwise operation

a = 1
b = 0
print(op.and_(a, b))
print(op.or_(a, b))
print(op.xor(a, b))
print(op.invert(a))
Exemplo n.º 22
0
# True division
print(operator.truediv(9.04,4))

# Absolute value
print(operator.abs(-10))

# Negation, equivalent to -a
print(operator.neg(-10))

# Bitwise inversion, equivalent to ~a (inv and invert are the same function)
# ~a = (-a)-1
# ~10 = -11
# ~(-10) = 9
print(operator.inv(10))
print(operator.inv(-10))
print(operator.invert(10))
print(operator.invert(-10))

# Exponentiation, same as a**b
print(operator.pow(2,3))

# Left shift, same as <<; equivalent to multiplying by 2 to that power
print(operator.lshift(3,2))

# Right shift, same as >>; equivalent to floor-dividing by 2 to that power
print(operator.rshift(3,2))

# Bitwise AND, i.e. a & b
print(operator.and_(1,8))
print(operator.and_(1,1))
Exemplo n.º 23
0
 def test__invert__(self):
     # operator.invert on xltypes.Number(2.0) must yield an object whose
     # .value is the float -2.0.
     number = xltypes.Number(2.0)
     res = operator.invert(number)
     self.assertIsInstance(res.value, float)
     self.assertEqual(res.value, -2.0)
Exemplo n.º 24
0
 def test_invert(self):
     import operator
     # On an SI quantity, operator.invert (~) gives the reciprocal: ~(5.0 m) == 0.2 / m.
     expected = 0.2 / SI.meter
     actual = operator.invert(5.0 * SI.meter)
     self.assertEqual(expected, actual)
     # Presumably SI.kibi is the plain int 1024, so ~ is ordinary integer inversion.
     self.assertEqual(-1025, ~SI.kibi)
     # Unary ~ binds tighter than *, so this is (~1) * SI.meter.
     self.assertEqual(-2 * SI.meter, ~1 * SI.meter)
Exemplo n.º 25
0
 def bitwise_not_usecase(x):
     # Deliberately the operator-module form of ~x.
     out = operator.invert(x)
     return out
Exemplo n.º 26
0
# operator.div existed only in Python 2; Python 3 splits it into truediv and
# floordiv, so this line stays disabled.
#print("4 / 2 = operator.div(4,2) = ", operator.div(4,2))

# True division
print("4 / 2 = operator.truediv(4,2) = ", operator.truediv(4,2))

# Floor division
print("4 // 2 = operator.floordiv(4, 2) = ", operator.floordiv(4, 2))

# Bitwise AND
print("0x0010 & 0x0011 = operator.and_(0x0010, 0x0011) = ", operator.and_(0x0010, 0x0011))

# Bitwise XOR
print("0x0010 ^ 0x0011 = operator.xor(0x0010, 0x0011) = ", operator.xor(0x0010, 0x0011))

# Bitwise NOT (~x == -x - 1)
print("~ 0x1000 = operator.invert(0x1000) = ", operator.invert(0x1000))

# Bitwise OR
print("0x0010 | 0x0011 = operator.or_(0x0010, 0x0011) = ", operator.or_(0x0010, 0x0011))

# Exponentiation
print("2 ** 16 = operator.pow(2, 16) = ", operator.pow(2, 16))

# Identity test (is)
print("1 is 1 = operator.is_(1,1) = ", operator.is_(1,1))

# Negated identity test (is not)
print("1 is not 2 = operator.is_not(1,2) = ", operator.is_not(1,2))

# Assign a value by index
obj = [1,2,3]
Exemplo n.º 27
0
 def __invert__(self):
     """Return ``~self`` as a new NonStandardInteger wrapping the inverted ``val``."""
     inverted = operator.invert(self.val)
     return NonStandardInteger(inverted)
Exemplo n.º 28
0
# Python code to demonstrate working of
# and_(), or_(), xor(), invert()

# importing operator module
import operator

# Initializing a and b

a = 1

b = 0

# using and_() to display bitwise and operation
print("The bitwise and of a and b is : ", end="")
print(operator.and_(a, b))

# using or_() to display bitwise or operation
print("The bitwise or of a and b is : ", end="")
print(operator.or_(a, b))

# using xor() to display bitwise exclusive or operation
print("The bitwise xor of a and b is : ", end="")
print(operator.xor(a, b))

# using invert() to invert value of a.  invert() returns a new value (ints
# are immutable), so the result must be assigned back to a.
a = operator.invert(a)

# printing modified value
print("The inverted value of a is : ", end="")
print(a)
Exemplo n.º 29
0
# ``li`` is assumed to be defined earlier in the script.
print("\nthe modified list after delitem() is : ", end="")
for item in li:
    print(item, end=" ")

print("\nthe 1st and 2nd element of list is : ", end="")
print(operator.getitem(li, slice(0, 2)))

# String operations: concatenation and containment
s1 = "geeksfor"
s2 = "geeks"
print("\nthe concatenated string is : ", end="")
print(operator.concat(s1, s2))

# operator.contains(a, b) is the function form of ``b in a``.
if operator.contains(s1, s2):
    print("geeksfor contains geeks")
else:
    print("geeksfor does not contain geeks")


# Bitwise operations
a = 3
b = 4
print("\nthe bitwise and of a and b is : ", end="")
print(operator.and_(a, b))
print("the bitwise or of a and b is : ", end="")
print(operator.or_(a, b))
print("the bitwise xor of a and b is : ", end="")
print(operator.xor(a, b))
print("the inverted value of a is : ", end="")
print(operator.invert(a))
Exemplo n.º 30
0
print("\r")
operator.setitem(li,slice(0,7),[7,8,9,0,12,13])
print("new list: ",end=" ")
for i in range(0,len(li)):
    print(li[i],end=" ")
print("\r")
operator.delitem(li,slice(1,2))
print("del list: ",end=" ")
for i in range(0,len(li)):
    print(li[i],end=" ")
print("\r")
print("item: ",end=" ")
print(operator.getitem(li,slice(1,3)))'''

import operator
'''a="viz grover"
b="grover"
c=0
print(operator.concat(a,b))
if(operator.contains(a,b)):
    print("true")
else:
    print("false")'''

# Apply the binary bitwise helpers to 0 and 1, then invert b.
a, b = 0, 1
for _bitwise in (operator.and_, operator.or_, operator.xor):
    print(_bitwise(a, b))
print(operator.invert(b))
Exemplo n.º 31
0
 def __invert__(self):
     # Wrap the inverted value so that ~x stays a NonStandardInteger.
     return NonStandardInteger(~self.val)
Exemplo n.º 32
0
def invert(a: int):
    """Assert that ``op.invert`` agrees with the standard ``operator.invert``."""
    candidate = op.invert(a)
    reference = operator.invert(a)
    assert candidate == reference
Exemplo n.º 33
0
def test_correctness(yamlstrs):
    """Tests the correctness of the constructors.

    Loads every resolved YAML string, then checks the ParamDim objects, the
    ParamSpace objects and the results of the utility constructors.
    """
    # Load the resolved yaml strings; tuple-named entries would fail to
    # load, so they are skipped.
    res = {}
    for name, ystr in yamlstrs.items():
        print("Name of yamlstr that will be loaded: ", name)
        if not isinstance(name, tuple):
            res[name] = yaml.load(ystr)

    # ParamDim objects: all of the first four default to 0.
    pdims = res["pdims_only"]["pdims"]
    expected_pdim_values = (
        (1, 2, 3),
        tuple(range(10)),
        tuple(np.linspace(1, 2, 3)),
        tuple(np.logspace(1, 2, 3)),
    )
    for idx, expected_values in enumerate(expected_pdim_values):
        assert pdims[idx].default == 0
        assert pdims[idx].values == expected_values
    assert pdims[4] == 0

    # ParamSpace objects
    for psp in res["pspace_only"].values():
        assert isinstance(psp, ParamSpace)

    # Utility constructors, checked in declaration order.
    utils = res["utils"]["utils"]
    assert round(9.87) == 10
    assert 2**4 % 3 == pow(2, 4, 3)
    expected_utils = {
        "any": any([False, 0, True]),
        "all": all([True, 5, 0]),
        "abs": abs(-1),
        "int": int(1.23),
        "round": round(9.87),
        "min": min([1, 2, 3]),
        "max": max([1, 2, 3]),
        "sorted": sorted([2, 1, 3]),
        "isorted": sorted([2, 1, 3], reverse=True),
        "sum": sum([1, 2, 3]),
        "prod": 2 * 3 * 4,
        "add": operator.add(1, 2),
        "sub": operator.sub(2, 1),
        "mul": operator.mul(3, 4),
        "truediv": operator.truediv(3, 2),
        "floordiv": operator.floordiv(3, 2),
        "mod": operator.mod(3, 2),
        "pow": 2**4,
        "pow_mod": 2**4 % 3,
        "pow_mod2": 2**4 % 3,
        "not": operator.not_(True),
        "and": operator.and_(True, False),
        "or": operator.or_(True, False),
        "xor": operator.xor(True, True),
        "lt": operator.lt(1, 2),
        "le": operator.le(2, 2),
        "eq": operator.eq(3, 3),
        "ne": operator.ne(3, 1),
        "ge": operator.ge(2, 2),
        "gt": operator.gt(4, 3),
        "negate": operator.neg(1),
        "invert": operator.invert(True),
        "contains": operator.contains([1, 2, 3], 4),
        "concat": [1, 2, 3] + [4, 5] + [6, 7, 8],
        "format1": "foo is not bar",
        "format2": "fish: spam",
        "format3": "results: 1.63 ± 0.03",
        "list1": [0, 2, 4, 6, 8],
        "list2": [3, 6, 9, 100],
        "lin": [-1.0, -0.5, 0.0, 0.5, 1.0],
        "log": [10.0, 100.0, 1000.0, 10000.0],
    }
    for key, expected in expected_utils.items():
        assert utils[key] == expected, key
    assert np.isclose(utils["arange"], [0.0, 0.2, 0.4, 0.6, 0.8]).all()

    # Mapping constructors
    assert utils["some_map"]
    assert utils["some_other_map"]
    assert utils["merged"] == {
        "foo": {
            "bar": "baz",
            "baz": "bar",
        },
        "spam": "fish",
        "fish": "spam",
    }
    assert utils["merged"] == recursive_update(utils["some_map"],
                                               utils["some_other_map"])
Exemplo n.º 34
0
    key.append(input_var)
    input_var_aux1 = input("Enter position:  ")
    input_var1 = (input_var_aux1 * 1023) / 300
    key1.append(int(input_var1) & 0xFF)
    input_var1 = (input_var1 >> 8)
    key1.append(input_var1)
    #input_var1= input("Enter speed: ")
    input_var1 = 100
    key1.append(int(input_var1) & 0xFF)
    input_var1 = (input_var1 >> 8)
    key1.append(input_var1)
print("ok")
# Checksum of the first packet: add up its entries, bitwise-invert the total,
# keep the low byte, then subtract 2 (the reason for "- 2" is not visible
# here -- TODO confirm against the protocol).
# NOTE(review): because of the "- 2" the result is negative when the low
# byte is 0 or 1, which is not a valid byte value.
for item in key:
    checksum = checksum + item
checksum = operator.invert(checksum) & 0x00FF
checksum = checksum - 2
key.append(checksum)
print("ok")
# Same checksum computation for the second packet.
for item in key1:
    checksum1 = checksum1 + item
checksum1 = operator.invert(checksum1) & 0x00FF
checksum1 = checksum1 - 2
key1.append(checksum1)
print("ok")
Exemplo n.º 35
0
 def bitwise_not_usecase_binary(x, _unused):
     # The second argument only fills a binary signature; it is ignored.
     out = operator.invert(x)
     return out
Exemplo n.º 36
0
# True division
print(operator.truediv(9.04, 4))

# Absolute value
print(operator.abs(-10))

# Negation, equivalent to -a
print(operator.neg(-10))

# Bitwise inversion, equivalent to ~a; inv() and invert() are the same function.
# ~a = (-a)-1
# ~10 = -11
# ~(-10) = 9
print(operator.inv(10))
print(operator.inv(-10))
print(operator.invert(10))
print(operator.invert(-10))

# Exponentiation, same as a**b
print(operator.pow(2, 3))

# Left shift, same as <<: multiplies by 2 to the shift count
print(operator.lshift(3, 2))

# Right shift, same as >>: floor-divides by 2 to the shift count
print(operator.rshift(3, 2))

# Bitwise AND, i.e. a & b
print(operator.and_(1, 8))
print(operator.and_(1, 1))
Exemplo n.º 37
0
class TVMScriptParser(Transformer):
    """Synr AST visitor pass which finally lowers to TIR.

    Notes for Extension
    -------------------
    1. To support a new type of AST node, add a function transform_xxx().
    2. To support new functions, add the function to the appropriate registry:
        We divide allowed function calls in TVM script into 3 categories,
        intrin, scope_handler and special_stmt.
        1. intrin functions are low level functions like mod, load, and
           constants. They correspond to a tir `IRNode`. They must have a
           return value. The user can register intrin functions for the parser to
           use.
        2. scope_handler functions have no return value. They take two
           arguments: the parser and the AST node. scope_handler functions are
           used in with and for statements.
        3. special_stmt functions handle cases that do not have a corresponding
           tir `IRNode`. These functions take the parser and the AST node as
           arguments and may return a value.
        When visiting a Call node, we check the special_stmt registry first. If
        no registered function is found, we then check the intrin registry.
        When visiting With node, we check the with_scope registry.
        When visiting For node, we check the for_scope registry.
    """

    _binop_maker = {
        ast.BuiltinOp.Add: tvm.tir.Add,
        ast.BuiltinOp.Sub: tvm.tir.Sub,
        ast.BuiltinOp.Mul: tvm.tir.Mul,
        ast.BuiltinOp.Div: tvm.tir.Div,
        ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv,
        ast.BuiltinOp.Mod: tvm.tir.FloorMod,
        ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs),
        ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs),
        ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs),
        ast.BuiltinOp.GT: tvm.tir.GT,
        ast.BuiltinOp.GE: tvm.tir.GE,
        ast.BuiltinOp.LT: tvm.tir.LT,
        ast.BuiltinOp.LE: tvm.tir.LE,
        ast.BuiltinOp.Eq: tvm.tir.EQ,
        ast.BuiltinOp.NotEq: tvm.tir.NE,
        ast.BuiltinOp.And: tvm.tir.And,
        ast.BuiltinOp.Or: tvm.tir.Or,
    }

    _unaryop_maker = {
        ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs),
        ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs),
        ast.BuiltinOp.Not: tvm.tir.Not,
    }

    def __init__(self, base_lienno, tir_namespace):
        # NOTE: ``base_lienno`` is a misspelling of "base_lineno"; the
        # parameter name is kept so keyword callers keep working.
        self.base_lineno = base_lienno
        self.tir_namespace = tir_namespace
        # Source position of the node currently being transformed.
        self.current_lineno = 0
        self.current_col_offset = 0
        # Filled in later by init_function_parsing_env() and init_meta().
        self.context = None
        self.meta = None

    def init_function_parsing_env(self):
        """Install a fresh ContextMaintainer (scope emitter) for a new function."""
        # The context reports problems back through this parser's report_error.
        error_reporter = self.report_error
        self.context = ContextMaintainer(error_reporter)

    def init_meta(self, meta_dict):
        """Load ``meta_dict`` into ``self.meta`` via a JSON round-trip.

        A ``None`` argument leaves ``self.meta`` untouched.
        """
        if meta_dict is None:
            return
        meta_json = json.dumps(meta_dict)
        self.meta = tvm.ir.load_json(meta_json)

    def transform(self, node):
        """Generic transformation for visiting the AST.

        Dispatches to ``transform_<ClassName>`` for the node's class, falling
        back to ``generic_visit``.  While the handler runs, the current line
        and column track ``node``; they are restored when the handler returns
        normally (not when it raises).
        """
        saved_position = (self.current_lineno, self.current_col_offset)

        # Nodes without position attributes keep the enclosing position.
        if hasattr(node, "lineno"):
            self.current_lineno = self.base_lineno + node.lineno - 1
        if hasattr(node, "col_offset"):
            self.current_col_offset = node.col_offset

        handler_name = f"transform_{node.__class__.__name__}"
        handler = getattr(self, handler_name, self.generic_visit)
        result = handler(node)

        self.current_lineno, self.current_col_offset = saved_position
        return result

    def match_tir_namespace(self, identifier: str) -> bool:
        """Return True when ``identifier`` is contained in ``self.tir_namespace``."""
        known_namespaces = self.tir_namespace
        return identifier in known_namespaces

    def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
        """Report an error occurring at a location.

        This just dispatches to synr's DiagnosticContext, converting a TVM
        span into a synr span first when necessary.

        Parameters
        ----------
        message : str
            Error message
        span : Union[synr.ast.Span, tvm.ir.Span]
            Location of the error
        """
        if isinstance(span, tvm.ir.Span):
            span = synr_span_from_tvm(span)
        self.error(message, span)

    def parse_body(self, parent):
        """Parse remaining statements in this scope.

        Pops statements off the innermost scope of ``self.context.node_stack``
        and transforms each one; results that are None are dropped.

        Parameters
        ----------
        parent : synr.ast.Node
            Parent node of this scope. Errors will be reported here.

        Returns
        -------
        The single parsed statement, or a ``tvm.tir.SeqStmt`` spanning all of
        them when there are several.  Returns None (after reporting an error)
        if the scope yields no statement and report_error returns.
        """
        body = []
        spans = []
        stmt = parent  # error location when the scope is already empty
        # The innermost scope is re-read on every iteration.
        while self.context.node_stack[-1]:
            stmt = self.context.node_stack[-1].pop()
            spans.append(stmt.span)
            parsed = self.transform(stmt)
            if parsed is not None:
                body.append(parsed)

        if not body:
            self.report_error(
                "Expected another statement at the end of this block. Perhaps you "
                "used a concise statement and forgot to include a body afterwards.",
                stmt.span,
            )
            return None
        if len(body) == 1:
            return body[0]
        return tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans)))

    def parse_arg_list(self, func, node_call):
        """Match the arguments of a function call in the AST to the required
        arguments of the function. This handles positional arguments,
        positional arguments specified by name, keyword arguments, and varargs.

        Parameters
        ----------
        func : Function
            The function that provides the signature

        node_call: ast.Call
            The AST call node that calls into the function.

        Returns
        -------
        arg_list : list
            The parsed positional argument.
        """
        assert isinstance(node_call, ast.Call)
        # Transform every call-site argument before matching.
        args = [self.transform(arg) for arg in node_call.params]
        kw_args = {
            self.transform(key): self.transform(value)
            for key, value in node_call.keyword_params.items()
        }

        # Only registered TVM script callables expose a signature().
        if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
            func_name, param_list = func.signature()
        else:
            self.report_error(
                "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, "
                f"but it is {type(func).__name__}",
                node_call.span,
            )

        reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
        pos_only, kwargs, varargs = param_list
        n_pos_only = len(pos_only)

        # Reader positions are 1-based: positional-only parameters first,
        # then the keyword-capable ones.
        internal_args = [
            reader.get_pos_only_arg(position, arg_name)
            for position, arg_name in enumerate(pos_only, start=1)
        ]
        for position, (arg_name, default) in enumerate(kwargs,
                                                       start=n_pos_only + 1):
            internal_args.append(
                reader.get_kwarg(position, arg_name, default=default))

        n_declared = n_pos_only + len(kwargs)
        n_supplied = len(args) + len(kw_args)
        if varargs is not None:
            internal_args.extend(reader.get_varargs(n_declared + 1))
        elif n_supplied > n_declared:
            self.report_error(
                f"Arguments mismatched. Expected {n_declared} args but got {n_supplied}",
                node_call.span,
            )
        return internal_args

    def parse_type(self, type_node, parent):
        """Parse a type annotation.

        ``parent`` supplies the span used to report a missing annotation.
        A transform result of None maps to the empty ``tvm.ir.TupleType``;
        otherwise the result's ``evaluate()`` value is returned.
        """
        if type_node is None:
            self.report_error("A type annotation is required", parent.span)
        res_type = self.transform(type_node)
        if res_type is None:
            return tvm.ir.TupleType([])
        return res_type.evaluate()

    def generic_visit(self, node):
        """Report an error for an AST node type that has no transform_* handler."""
        node_kind = type(node).__name__
        self.report_error(f"{node_kind} AST node is not supported", node.span)

    def transform_Module(self, node):
        """Module visitor

        Right now, we only support two formats for TVM Script.

        Example
        -------
        1. Generate a PrimFunc (If the code is printed, then it may also contain metadata)
        .. code-block:: python

            import tvm

            @tvm.script
            def A(...):
                ...

            # returns a PrimFunc
            func = A

        2. Generate an IRModule
        .. code-block:: python

            import tvm

            @tvm.script.ir_module
            class MyMod():
                @T.prim_func
                def A(...):
                    ...
                @T.prim_func
                def B(...):
                    ...

                __tvm_meta__ = ...

            # returns an IRModule
            mod = MyMod

        Exactly one top-level function or class must be present; otherwise an
        error is reported.
        """
        funcs = node.funcs
        if len(funcs) == 1:
            return self.transform(next(iter(funcs.values())))
        # ``funcs`` (not ``func``): the empty case previously raised
        # AttributeError instead of reporting a diagnostic.
        if len(funcs) == 0:
            self.report_error(
                "You must supply at least one class or function definition",
                node.span)
        else:
            self.report_error(
                "Only one-function, one-class or function-with-meta source code is allowed",
                ast.Span.union([x.span
                                for x in list(funcs.values())[1:]]),
            )

    def transform_Class(self, node):
        """Class definition visitor.

        A class can have multiple function definitions and a single
        :code:`__tvm_meta__` statement. Each class corresponds to a single
        :code:`IRModule`.

        Example
        -------
        .. code-block:: python

            @tvm.script.ir_module
            class MyClass:
                __tvm_meta__ = {}
                @T.prim_func
                def A():
                    T.evaluate(0)
        """
        assignments = node.assignments
        if len(assignments) > 1:
            self.report_error(
                "Only a single top level `__tvm_meta__` is allowed",
                ast.Span.union([x.span for x in assignments[1:]]),
            )
        elif len(assignments) == 1:
            meta_assign = assignments[0]
            targets = meta_assign.lhs
            is_meta_assign = (len(targets) == 1
                              and isinstance(targets[0], ast.Var)
                              and targets[0].id.name == "__tvm_meta__")
            if not is_meta_assign:
                self.report_error(
                    "The only top level assignments allowed are `__tvm_meta__ = ...`",
                    meta_assign.span,
                )
            self.init_meta(MetaUnparser().do_transform(
                meta_assign.rhs, self._diagnostic_context))

        # Every function in the class becomes one global of the IRModule.
        return IRModule({
            GlobalVar(name): self.transform(func)
            for name, func in node.funcs.items()
        })

    def transform_Function(self, node):
        """Function definition visitor.

        Each function definition is translated to a single :code:`PrimFunc`.

        There are a couple restrictions on TVM Script functions:
        1. Function arguments must have their types specified.
        2. The body of the function can contain :code:`func_attr` to specify
           attributes of the function (like its name).
        3. The body of the function can also contain multiple :code:`buffer_bind`s,
           which give shape and dtype information to arguments.
        4. Return statements are implicit.

        The function must carry exactly one decorator of the form
        ``<tir alias>.prim_func`` (for example ``T.prim_func``); otherwise an
        error is reported. After the body is parsed, the resulting
        :code:`PrimFunc` is passed through ``_ffi_api.Complete`` together with
        the buffers allocated in the implicit root block.

        Example
        -------
        .. code-block:: python

            @T.prim_func
            def my_function(x: T.handle):  # 1. Argument types
                T.func_attr({"global_symbol": "mmult"})  # 2. Function attributes
                X_1 = tir.buffer_bind(x, [1024, 1024])  # 3. Buffer binding
                T.evaluate(0)  # 4. This function returns 0
        """
        def check_decorator(decorators: List[ast.Expr]) -> bool:
            """Check the decorator is `T.prim_func`.

            Returns True only if there is exactly one decorator and it is an
            attribute access ``<name>.prim_func`` where ``<name>`` is accepted
            by ``self.match_tir_namespace``.
            """
            if len(decorators) != 1:
                return False
            d: ast.Expr = decorators[0]
            return (isinstance(d, ast.Attr) and isinstance(d.object, ast.Var)
                    and self.match_tir_namespace(d.object.id.name)
                    and d.field.name == "prim_func")

        self.init_function_parsing_env()
        self.context.enter_scope(nodes=node.body.stmts)

        # add parameters of function
        # Each parameter becomes a typed TIR var that is registered as a symbol
        # in the function scope and appended to the PrimFunc parameter list.
        for arg in node.params:
            arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
            self.context.update_symbol(arg.name, arg_var, node)
            self.context.func_params.append(arg_var)

        if not check_decorator(node.decorators):
            self.report_error(
                "All functions should be decorated by `T.prim_func`",
                node.span,
            )

        # New Scope : Implicit root block
        # Each function contains an implicit root block in TensorIR,
        # so here we need a block scope for it. Please note that `enter_block_scope`
        # will not create a block directly but just stores some information.
        # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
        # the root block will not be added. The logic to add root block is in `_ffi_api.Complete`
        self.context.enter_block_scope(nodes=node.body.stmts)

        # fetch the body of root block
        body = self.parse_body(node)
        # Emit Scope : Implicit root block
        # Capture the root block info (its alloc_buffers are needed by
        # `_ffi_api.Complete` below) before leaving the block scope.
        root_info: BlockInfo = self.context.current_block_scope()
        self.context.exit_block_scope()

        # return a tir.PrimFunc
        # Function attributes collected while parsing (e.g. via `T.func_attr`)
        # are attached as DictAttrs only when any were provided.
        dict_attr = self.context.func_dict_attr
        ret_type = self.parse_type(node.ret_type,
                                   node) if node.ret_type is not None else None
        func = tvm.tir.PrimFunc(
            self.context.func_params,
            body,
            ret_type,
            buffer_map=self.context.func_buffer_map,
            attrs=tvm.ir.make_node("DictAttrs", **dict_attr)
            if dict_attr else None,
            span=tvm_span_from_synr(node.span),
        )

        # Fix the PrimFunc
        # 1. generate root block if necessary
        # 2. generate surrounding loops for blocks if necessary

        func = call_with_error_reporting(
            self.report_error,
            node.span,
            _ffi_api.Complete,
            func,
            root_info.alloc_buffers,
        )

        # Close the function scope opened at the top of this method.
        self.context.exit_scope()
        return func

    def transform_Assign(self, node):
        """Assign visitor
        AST abstract grammar:
            Assign(expr* targets, expr value, string? type_comment)

        By now 4 patterns of Assign are supported:
            1. special stmts with return value
                1.1 Buffer = T.match_buffer()/T.buffer_decl()
                1.2 Var = T.var()
                1.3 Var = T.env_thread()
            2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
            3. (Store)       Var[PrimExpr] = PrimExpr
            4. with scope handlers with concise scoping and var def
                4.1 var = T.allocate()

        Patterns 2 and 3 are parsed as unassigned calls and lowered by
        :code:`transform_SubscriptAssign` (see :code:`transform_UnassignedCall`).
        This visitor handles call right-hand sides only:

        - a :code:`WithScopeHandler` callee (pattern 4) must support concise
          scoping with a symbol definition;
        - a :code:`SpecialStmt` callee (pattern 1) is handled in place and the
          rest of the body is parsed;
        - any other call is lowered to a :code:`tvm.tir.LetStmt` binding exactly
          one unqualified variable for the remainder of the body.

        Non-call right-hand sides are reported as unsupported.
        """

        if isinstance(node.rhs, ast.Call):
            # Pattern 1 & Pattern 4
            func = self.transform(node.rhs.func_name)
            if isinstance(func, WithScopeHandler):
                if not func.concise_scope or not func.def_symbol:
                    self.report_error(
                        "with scope handler " + func.signature()[0] +
                        " is not suitable here",
                        node.rhs.span,
                    )
                # Pattern 4
                arg_list = self.parse_arg_list(func, node.rhs)
                func.enter_scope(node, self.context, arg_list,
                                 node.rhs.func_name.span)
                func.body = self.parse_body(node)
                return func.exit_scope(node, self.context, arg_list,
                                       node.rhs.func_name.span)
            elif isinstance(func, SpecialStmt):
                # Pattern 1
                arg_list = self.parse_arg_list(func, node.rhs)
                func.handle(node, self.context, arg_list,
                            node.rhs.func_name.span)
                return self.parse_body(node)
            else:
                value = self.transform(node.rhs)
                # A LetStmt binds exactly one variable, so reject both a
                # qualified target (e.g. `a.b = ...`) and multiple targets
                # (e.g. `a, b = ...`); otherwise only the first target would be
                # bound and the remaining targets silently dropped.
                if len(node.lhs) != 1 or not isinstance(node.lhs[0], ast.Var):
                    self.report_error(
                        "Left hand side of assignment must be an unqualified variable",
                        node.span,
                    )
                ast_var = node.lhs[0]
                var = tvm.te.var(
                    ast_var.id.name,
                    self.parse_type(node.ty, ast_var),
                    span=tvm_span_from_synr(ast_var.span),
                )
                # The variable is visible only while the rest of the body is
                # parsed, i.e. inside the LetStmt.
                self.context.update_symbol(var.name, var, node)
                body = self.parse_body(node)
                self.context.remove_symbol(var.name)
                return tvm.tir.LetStmt(var,
                                       value,
                                       body,
                                       span=tvm_span_from_synr(node.span))

        self.report_error("Unsupported Assign stmt", node.span)

    def transform_SubscriptAssign(self, node):
        """Visitor for statements of the form :code:`x[1] = 2`.

        ``node.params`` is read positionally as ``(target, indexes, value)``:

        - when the target resolves to a :code:`tvm.tir.Buffer`, a
          :code:`tvm.tir.BufferStore` over all given indexes is produced;
        - otherwise exactly one index is required and a :code:`tvm.tir.Store`
          with a constant ``True`` predicate is produced.
        """
        params = node.params
        target = self.transform(params[0])
        indexes = self.transform(params[1])
        value = self.transform(params[2])
        value_span = tvm_span_from_synr(params[2].span)

        if not isinstance(target, tvm.tir.Buffer):
            # Store: flat, single-index write into a non-buffer symbol.
            if len(indexes) != 1:
                self.report_error(
                    f"Store is only allowed with one index, but {len(indexes)} were provided.",
                    params[1].span,
                )
            return tvm.tir.Store(
                target,
                tvm.runtime.convert(value, span=value_span),
                indexes[0],
                tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)),
                span=tvm_span_from_synr(node.span),
            )

        # BufferStore: multi-dimensional write into a buffer.
        return tvm.tir.BufferStore(
            target,
            tvm.runtime.convert(value, span=value_span),
            indexes,
            span=tvm_span_from_synr(node.span),
        )

    def transform_Assert(self, node):
        """Assert visitor.

        Lowers :code:`assert condition, message` into a
        :code:`tvm.tir.AssertStmt` that guards the rest of the enclosing body.
        This is the concise form of :code:`with T.Assert()`, and the message
        is mandatory.
        """
        condition = self.transform(node.condition)
        if node.msg is None:
            self.report_error("Assert statements must have an error message.",
                              node.span)
        msg_value = self.transform(node.msg)
        # Everything after the assert becomes the guarded body.
        guarded = self.parse_body(node)
        return tvm.tir.AssertStmt(
            condition,
            tvm.runtime.convert(msg_value),
            guarded,
            span=tvm_span_from_synr(node.span),
        )

    def transform_For(self, node):
        """For visitor
        AST abstract grammar:
            For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
        By now 1 pattern of For is supported:
            1. for scope handler
                for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
                            T.grid()/T.thread_binding()

        The iterator expression must be a call whose callee resolves to a
        :code:`ForScopeHandler`; anything else is reported as an error. The
        handler's ``enter_scope``/``exit_scope`` pair builds the statement that
        is returned. While the loop is being lowered, ``self.current_lineno``
        and ``self.current_col_offset`` point at the ``for`` statement; they are
        restored before returning.
        """

        if not isinstance(node.rhs, ast.Call):
            self.report_error("The loop iterator should be a function call.",
                              node.rhs.span)
        func = self.transform(node.rhs.func_name)
        if not isinstance(func, ForScopeHandler):
            self.report_error(
                "Only For scope handlers can be used in a for statement.",
                node.rhs.func_name.span)
        # prepare for new for scope
        # Save the current source position so it can be restored after the
        # loop has been lowered.
        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
        self.current_lineno = node.span.start_line
        self.current_col_offset = node.span.start_column
        self.context.enter_scope(nodes=node.body.stmts)
        # for scope handler process the scope
        # The handler's arguments are parsed inside the new scope.
        arg_list = self.parse_arg_list(func, node.rhs)
        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
        func.body = self.parse_body(node)
        res = func.exit_scope(node, self.context, arg_list,
                              node.rhs.func_name.span)
        # exit the scope
        self.context.exit_scope()
        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
        return res

    def transform_While(self, node):
        """While visitor.

        AST abstract grammar:
            While(expr condition, stmt* body)

        The condition is transformed in the current scope. The loop body is
        parsed inside a fresh scope and wrapped in a :code:`tvm.tir.While`.
        """
        loop_cond = self.transform(node.condition)

        # Loop body lives in its own scope.
        self.context.enter_scope(nodes=node.body.stmts)
        loop_body = self.parse_body(node)
        self.context.exit_scope()

        loop_span = tvm_span_from_synr(node.span)
        return tvm.tir.While(loop_cond, loop_body, span=loop_span)

    def transform_With(self, node):
        """With visitor
        AST abstract grammar:
            With(withitem* items, stmt* body, string? type_comment)
            withitem = (expr context_expr, expr? optional_vars)
        By now 2 patterns of With is supported:
            1. with scope handler with symbol def
                with T.block(*axes)/T.allocate() as targets:
            2. with scope handler without symbol def
                with T.let()/T.Assert()/T.attr()/T.realize()

        The context expression must be a call whose callee resolves to a
        :code:`WithScopeHandler`. Unlike :code:`transform_For`, the body is
        parsed inside a *block* scope, and the current source position is moved
        to the start of the ``with`` body (not the statement itself) while the
        body is lowered. The position is restored before returning.
        """

        if not isinstance(node.rhs, ast.Call):
            self.report_error(
                "The context expression of a `with` statement should be a function call.",
                node.rhs.span,
            )

        func = self.transform(node.rhs.func_name)

        if not isinstance(func, WithScopeHandler):
            self.report_error(
                f"Function {func} cannot be used in a `with` statement.",
                node.rhs.func_name.span)
        # prepare for new block scope
        # Save the current source position so it can be restored afterwards.
        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
        self.current_lineno = node.body.span.start_line
        self.current_col_offset = node.body.span.start_column
        self.context.enter_block_scope(nodes=node.body.stmts)
        # with scope handler process the scope
        arg_list = self.parse_arg_list(func, node.rhs)
        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
        func.body = self.parse_body(node)
        res = func.exit_scope(node, self.context, arg_list,
                              node.rhs.func_name.span)
        # exit the scope
        self.context.exit_block_scope()
        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
        return res

    def transform_If(self, node):
        """If visitor.

        AST abstract grammar:
            If(expr test, stmt* body, stmt* orelse)

        Each branch is parsed in its own scope. An empty ``else`` branch is
        represented as ``None`` in the resulting :code:`tvm.tir.IfThenElse`.
        """
        cond = self.transform(node.condition)

        # "then" branch
        self.context.enter_scope(nodes=node.true.stmts)
        then_branch = self.parse_body(node)
        self.context.exit_scope()

        # "else" branch (optional)
        else_branch = None
        if node.false.stmts:
            self.context.enter_scope(nodes=node.false.stmts)
            else_branch = self.parse_body(node)
            self.context.exit_scope()

        return tvm.tir.IfThenElse(cond,
                                  then_branch,
                                  else_branch,
                                  span=tvm_span_from_synr(node.span))

    def transform_Call(self, node):
        """Call visitor

        3 different Call patterns are allowed:
            1. Intrin representing a PrimExpr/IterVar
                1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
                1.2 tir.range/reduce_axis/scan_axis/opaque_axis
            2. tir.Op(dtype, ...)
            3. other callable functions

        Builtin operators (``ast.Op``) are dispatched first. Subscripts go to
        :code:`transform_Subscript`. Binary and unary operators are built with
        ``self._binop_maker`` / ``self._unaryop_maker``. Any other builtin
        operator is reported as unsupported.

        For pattern 2 the result dtype comes from the mandatory ``dtype``
        keyword argument; no other keyword arguments are forwarded to
        :code:`tvm.tir.Call`. A missing ``dtype`` is reported as an error at
        the callee.
        """

        if isinstance(node.func_name, ast.Op):
            if node.func_name.name == ast.BuiltinOp.Subscript:
                return self.transform_Subscript(node)
            if node.func_name.name in self._binop_maker:
                lhs = self.transform(node.params[0])
                rhs = self.transform(node.params[1])
                return self._binop_maker[node.func_name.name](
                    lhs, rhs, span=tvm_span_from_synr(node.span))
            if node.func_name.name in self._unaryop_maker:
                rhs = self.transform(node.params[0])
                return self._unaryop_maker[node.func_name.name](
                    rhs, span=tvm_span_from_synr(node.span))
            self.report_error(f"Unsupported operator {node.func_name.name}.",
                              node.func_name.span)
        else:
            func = self.transform(node.func_name)
            if isinstance(func, Intrin) and not func.stmt:
                # pattern 1
                arg_list = self.parse_arg_list(func, node)
                return call_with_error_reporting(
                    self.report_error,
                    node.func_name.span,
                    func.handle,
                    arg_list,
                    node.func_name.span,
                )
            else:
                args = [self.transform(arg) for arg in node.params]
                kw_args = {
                    self.transform(k): self.transform(v)
                    for k, v in node.keyword_params.items()
                }
                if isinstance(func, tvm.tir.op.Op):
                    # pattern 2
                    # `dtype` is required to build the Call. Report a proper
                    # diagnostic at the callee instead of letting a bare
                    # KeyError escape without source location information.
                    if "dtype" not in kw_args:
                        self.report_error(
                            "A `dtype` keyword argument is required when calling a tvm.tir.op.Op.",
                            node.func_name.span,
                        )
                    return tvm.tir.Call(kw_args["dtype"],
                                        func,
                                        args,
                                        span=tvm_span_from_synr(node.span))
                elif callable(func):
                    # pattern 3
                    return func(*args, **kw_args)
                else:
                    self.report_error(
                        f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).",
                        node.func_name.span,
                    )

    def transform_UnassignedCall(self, node):
        """Visitor for statements that are function calls.

        This handles function calls that appear on their own line like `tir.realize`.

        Accepted forms:

        - ``x[...] = value`` (the builtin SubscriptAssign operator), forwarded
          to :code:`transform_SubscriptAssign`;
        - an :code:`Intrin` whose ``stmt`` flag is set;
        - a :code:`WithScopeHandler` with concise scoping and no symbol
          definition (its body is the rest of the enclosing block);
        - a :code:`SpecialStmt` that does not define a symbol.

        Anything else, including a standalone ``T.Assert``, is reported as an
        error.

        Examples
        --------
        .. code-block:: python

            @T.prim_func
            def f():
                A = T.buffer_decl([10, 10])
                T.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
                A[1, 1] = 2  # This is also an UnassignedCall
        """
        # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
        if isinstance(node.call.func_name, ast.Op):
            if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign:
                self.report_error(
                    "Binary and unary operators are not allowed as a statement",
                    node.span)
            else:
                return self.transform_SubscriptAssign(node.call)

        # handle a regular function call
        func = self.transform(node.call.func_name)
        arg_list = self.parse_arg_list(func, node.call)

        # `T.Assert` needs a body, so it is only accepted via `assert` syntax.
        if isinstance(func, tir.scope_handler.AssertHandler):
            self.report_error(
                "A standalone `T.Assert` is not allowed. Use `assert condition, message` "
                "instead.",
                node.call.func_name.span,
            )

        if isinstance(func, Intrin):
            if func.stmt:
                return call_with_error_reporting(
                    self.report_error,
                    node.call.func_name.span,
                    func.handle,
                    arg_list,
                    node.call.func_name.span,
                )
            else:
                self.report_error(
                    f"This intrinsic cannot be used as a statement.",
                    node.call.span)
        elif isinstance(func, WithScopeHandler
                        ) and func.concise_scope and not func.def_symbol:
            # Concise `with` form: the rest of the enclosing body becomes the
            # handler's body.
            func.enter_scope(node, self.context, arg_list,
                             node.call.func_name.span)
            func.body = self.parse_body(node)
            return func.exit_scope(node, self.context, arg_list,
                                   node.call.func_name.span)
        elif isinstance(func, SpecialStmt) and not func.def_symbol:
            # Special statements only update parser context; they produce no
            # statement of their own.
            func.handle(node, self.context, arg_list, node.call.func_name.span)
            return

        self.report_error(
            "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
            f"special statement, but got {type(func).__name__}.",
            node.call.func_name.span,
        )

    def transform_Slice(self, node):
        """Slice visitor for ``start:end`` inside a subscript.

        Only a constant step of 1 is accepted; any other step is reported as
        an error at the step's span. Returns ``Slice(start, end)``.
        """
        lower = self.transform(node.start)
        upper = self.transform(node.end)
        step = node.step
        has_unit_step = isinstance(step, ast.Constant) and step.value == 1
        if not has_unit_step:
            self.report_error("Only step size 1 is supported for slices.",
                              step.span)
        return Slice(lower, upper)

    def transform_Subscript(self, node):
        """Array access visitor.

        By now only 3 types of Subscript are supported:
            1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
               Var[index] Buffer element access()
            2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
            3. Array[index], Buffer element access

        Dispatch is on the type of the subscripted symbol:

        - :code:`tvm.tir.expr.Var`: every index must be an ``int`` or
          :code:`tvm.tir.PrimExpr`; a :code:`tvm.tir.Load` is produced.
        - :code:`tvm.tir.Buffer`: a :code:`BufferSlice` is produced.
        - :code:`tvm.container.Array`: exactly one ``int``/``IntImm`` index,
          smaller than the array length, is required; the selected element is
          returned.

        Any other symbol type is reported as an error.
        """

        symbol = self.transform(node.params[0])
        if symbol is None:
            self.report_error(
                f"Variable {node.params[0].id.name} is not defined.",
                node.params[0].span)

        indexes = [self.transform(x) for x in node.params[1].values]
        if isinstance(symbol, tvm.tir.expr.Var):
            for index in indexes:
                if not isinstance(index, (tvm.tir.PrimExpr, int)):
                    # Format the type name; concatenating a str with a type
                    # object would raise TypeError instead of reporting.
                    self.report_error(
                        "Buffer load indexes should be int or PrimExpr, but they are "
                        f"{type(index).__name__}",
                        node.span,
                    )
            # NOTE(review): the load dtype is hard-coded to "float32" whatever
            # the variable points to — confirm this is intended.
            return tvm.tir.Load("float32",
                                symbol,
                                indexes,
                                True,
                                span=tvm_span_from_synr(node.span))
        elif isinstance(symbol, tvm.tir.Buffer):
            return BufferSlice(symbol,
                               indexes,
                               self.report_error,
                               span=tvm_span_from_synr(node.span))
        elif isinstance(symbol, tvm.container.Array):
            # Exactly one index: an empty index list would otherwise crash on
            # `indexes[0]` below.
            if len(indexes) != 1:
                self.report_error(
                    "Array access should be one-dimension access, but the indices are "
                    + str(indexes),
                    node.span,
                )
            index = indexes[0]
            if not isinstance(index, (int, tvm.tir.expr.IntImm)):
                self.report_error(
                    "Array access index expected int or IntImm, but got "
                    f"{type(index).__name__}",
                    node.span,
                )
            if int(index) >= len(symbol):
                self.report_error(
                    f"Array access out of bound, size: {len(symbol)}, got index {index}.",
                    node.span,
                )
            return symbol[int(index)]
        else:
            self.report_error(
                f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
                "buffers are supported.",
                node.params[0].span,
            )

    def transform_Attr(self, node):
        """Visitor for field access of the form `x.y`.

        This visitor is used to lookup function and symbol names. We have two
        cases to handle here:
        1. If we have a statement of the form `tir.something`, then we lookup
           `tir.something` in the `Registry`. If the function is not in the
           registry, then we try to find a `tvm.ir.op.Op` with the same name.
           An operator lookup that fails with an "AttributeError" message is
           reported as an unregistered function; any other error is re-raised.
        2. All other names `tvm.something` are looked up in this current python
           namespace, via attribute access on the transformed object.
        """
        def get_full_attr_name(node: ast.Attr) -> str:
            """Join a chain of attribute accesses into a dotted name."""
            reverse_field_names = [node.field.name]
            while isinstance(node.object, ast.Attr):
                node = node.object
                reverse_field_names.append(node.field.name)
            if isinstance(node.object, ast.Var):
                reverse_field_names.append(node.object.id.name)
            return ".".join(reversed(reverse_field_names))

        if isinstance(node.object, (ast.Var, ast.Attr)):
            full_attr_name = get_full_attr_name(node)
            attr_object, fields = full_attr_name.split(".", maxsplit=1)
            if self.match_tir_namespace(attr_object):
                func_name = "tir." + fields
                res = Registry.lookup(func_name)
                if res is not None:
                    return res
                try:
                    return tvm.ir.op.Op.get(func_name)
                except TVMError as e:
                    # Check if we got an attribute error. `str.find` returns -1
                    # (truthy) when the text is absent and 0 (falsy) when it is
                    # at the start, so a membership test is required here.
                    if "AttributeError" in str(e):
                        self.report_error(
                            f"Unregistered function `tir.{fields}`.",
                            node.span)
                    else:
                        raise

        symbol = self.transform(node.object)
        if symbol is None:
            self.report_error("Unsupported Attribute expression.",
                              node.object.span)
        if not hasattr(symbol, node.field.name):
            self.report_error(
                f"Type {type(symbol)} does not have a field called `{node.field.name}`.",
                node.span)
        res = getattr(symbol, node.field.name)
        return res

    def transform_TypeAttr(self, node):
        """Visitor for field access of the form `x.y` for types.

        We have two cases here:
        1. If the type is of the form `T.something`, we look up the type in
           the `tir` namespace in this module.
        2. If the type is of the form `tvm.x.something` then we look up
           `tvm.x.something` in this modules namespace.

        In both cases the field is looked up by its identifier name
        (``node.field.name``), the same way :code:`transform_Attr` does.
        """
        if isinstance(node.object, ast.TypeVar):
            if self.match_tir_namespace(node.object.id.name):
                if not hasattr(tir, node.field.name):
                    self.report_error(
                        f"Invalid type annotation `tir.{node.field.name}`.",
                        node.span)
                return getattr(tir, node.field.name)

        symbol = self.transform(node.object)
        if symbol is None:
            self.report_error("Unsupported Attribute expression",
                              node.object.span)
        # `node.field` is an identifier node, not a string: passing it to
        # hasattr/getattr would raise TypeError, so use its name.
        field_name = node.field.name
        if not hasattr(symbol, field_name):
            self.report_error(
                f"Type {type(symbol)} does not have a field called `{field_name}`.",
                node.span)
        res = getattr(symbol, field_name)
        return res

    def transform_DictLiteral(self, node):
        """Dictionary literal visitor.

        Handles dictionary literals of the form `{x:y, z:2}`. All keys are
        transformed first (in order), then all values, and the results are
        paired positionally; a repeated key keeps its last value.
        """
        transformed_keys = list(map(self.transform, node.keys))
        transformed_values = list(map(self.transform, node.values))
        result = {}
        for key, value in zip(transformed_keys, transformed_values):
            result[key] = value
        return result

    def transform_Tuple(self, node):
        """Tuple visitor.

        Handles tuples of the form `(x, y, 2)` by transforming each element
        in order and returning them as a Python tuple.
        """
        items = []
        for element in node.values:
            items.append(self.transform(element))
        return tuple(items)

    def transform_ArrayLiteral(self, node):
        """List literal visitor.

        Handles lists of the form `[x, 2, 3]`; each element is transformed in
        order and the results are returned as a Python list.
        """
        return list(map(self.transform, node.values))

    def transform_Var(self, node):
        """Variable visitor.

        Handles variables like `x` in `x = 2`. Name resolution order:
        the reserved name ``meta`` (the parser's meta table), then the
        `Registry`, then symbols in the current parsing context. An unknown
        name is reported as an error.
        """
        ident = node.id.name
        if ident == "meta":
            return self.meta
        registered = Registry.lookup(ident)
        if registered is not None:
            return registered
        local = self.context.lookup_symbol(ident)
        if local is None:
            self.report_error(f"Unknown identifier {ident}.", node.span)
        return local

    def transform_TypeVar(self, node):
        """Type variable visitor.

        Equivalent to `transform_Var` but for types (without the special
        ``meta`` name). The `Registry` is consulted first; if it yields nothing
        truthy, the current parsing context is consulted instead.
        """
        ident = node.id.name
        found = Registry.lookup(ident)
        if not found:
            found = self.context.lookup_symbol(ident)
        if found is None:
            self.report_error(f"Unknown identifier {ident}.", node.span)
            return None
        return found

    def transform_Constant(self, node):
        """Constant value visitor.

        Constant values include `None`, `"strings"`, `2` (integers), `4.2`
        (floats), and `True`/`False` (booleans). The literal is converted with
        :code:`tvm.runtime.convert` and tagged with the node's source span.
        """
        source_span = tvm_span_from_synr(node.span)
        return tvm.runtime.convert(node.value, span=source_span)

    def transform_TypeConstant(self, node):
        """Constant value visitor for types.

        Unlike `transform_Constant`, the literal value is returned as-is,
        without any conversion into a TVM object.
        """
        literal = node.value
        return literal

    def transform_Return(self, node):
        """Reject explicit ``return`` statements.

        TVM script has no explicit return; the last statement of a block is
        returned implicitly. Any ``return`` is reported as an error at the
        statement's span.
        """
        message = (
            "TVM script does not support return statements. Instead the last statement in any "
            "block is implicitly returned."
        )
        self.report_error(message, node.span)
Exemplo n.º 38
0
def invert(x):
    """Return the bitwise inversion of *x*, i.e. ``~x``."""
    return ~x

@builtin.predicate()