def expected():
    """Build the expected graph: two "add_sub_mul" composite functions chained
    (the second consuming the first's result), followed by relu.

    Each composite computes (x + y) * (x - y) and is tagged Primitive=1 and
    Composite="add_sub_mul" so the partitioner treats it as one fused unit.
    """
    a = relay.var('a', shape=(10, 10))
    b = relay.var('b', shape=(10, 10))
    c = relay.var('c', shape=(10, 10))
    # add_sub_mul function
    in_1 = relay.var('in_1', shape=(10, 10))
    in_2 = relay.var('in_2', shape=(10, 10))
    add_node = relay.add(in_1, in_2)
    sub_node = relay.subtract(in_1, in_2)
    mul_node = relay.multiply(add_node, sub_node)
    add_sub_mul = relay.Function([in_1, in_2], mul_node)
    add_sub_mul = add_sub_mul.set_attribute("Primitive", tir.IntImm("int32", 1))
    add_sub_mul = add_sub_mul.set_attribute("Composite", tir.StringImm("add_sub_mul"))
    # add_sub_mul1 function (a second, distinct instance of the same pattern)
    in_3 = relay.var('in_3', shape=(10, 10))
    in_4 = relay.var('in_4', shape=(10, 10))
    add_node_1 = relay.add(in_3, in_4)
    sub_node_1 = relay.subtract(in_3, in_4)
    mul_node_1 = relay.multiply(add_node_1, sub_node_1)
    add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
    add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive", tir.IntImm("int32", 1))
    add_sub_mul_1 = add_sub_mul_1.set_attribute(
        "Composite", tir.StringImm("add_sub_mul"))
    # merged function: second composite consumes the first composite's output
    m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
    m_add_sub_mul_2 = relay.Call(add_sub_mul_1, [c, m_add_sub_mul_1])
    r = relay.nn.relu(m_add_sub_mul_2)
    return relay.Function([a, b, c], r)
def _create_schedule():
    """Build a tiled GPU schedule for `matmul`: split/reorder the loops,
    bind them to block/vthread/thread axes, stage A and B through shared
    memory, write C through a local cache, and enable auto-unrolling.

    NOTE(review): the "outer: N" extents in the comments below assume the
    shapes baked into `matmul` (defined elsewhere) — confirm against it.
    """
    func = matmul
    sch = tir.Schedule(func, debug_mask="all")
    c = sch.get_block("C")
    # Accumulate into a thread-local cache, written back to C afterwards.
    c_local = sch.cache_write(c, 0, "local")
    i, j, k = sch.get_loops(c)  # pylint: disable=invalid-name
    i0, i1, i2, i3, i4 = sch.split(i, factors=[None, 1, 16, 32, 1])  # outer: 1
    j0, j1, j2, j3, j4 = sch.split(j, factors=[None, 4, 1, 1, 16])  # outer: 8
    k0, k1, k2 = sch.split(k, factors=[None, 1, 2])  # outer: 256
    # pylint: enable=invalid-name
    # Interleave spatial (S) and reduction (R) loops for the tiling structure.
    # fmt: off
    sch.reorder(
        i0, j0,  # S
        i1, j1,  # S
        i2, j2,  # S
        k0,      # R
        k1,      # R
        i3, j3,  # S
        k2,      # R
        i4, j4,  # S
    )
    # fmt: on
    # thread binding: fused outer spatial loops map onto the GPU hierarchy
    i0_j0 = sch.fuse(i0, j0)
    i1_j1 = sch.fuse(i1, j1)
    i2_j2 = sch.fuse(i2, j2)
    sch.bind(i0_j0, "blockIdx.x")
    sch.bind(i1_j1, "vthread.x")
    sch.bind(i2_j2, "threadIdx.x")
    # fusion: write-back of the local cache happens per threadIdx iteration
    sch.reverse_compute_at(c_local, i2_j2)
    # cache read 'A': stage into shared memory inside the outer reduction loop
    a_shared = sch.cache_read(c, 1, "shared")
    sch.compute_at(a_shared, k0)
    _, _, _, _, a_i, a_j = sch.get_loops(a_shared)
    a_ij = sch.fuse(a_i, a_j)
    # Cooperative fetch: inner 16 lanes match the threadIdx.x extent above.
    _, a_j = sch.split(a_ij, factors=[None, 16])  # outer: 64
    sch.bind(a_j, "threadIdx.x")
    # cache read 'B': same staging pattern as 'A'
    b_shared = sch.cache_read(c, 2, "shared")
    sch.compute_at(b_shared, k0)
    _, _, _, _, b_i, b_j = sch.get_loops(b_shared)
    b_ij = sch.fuse(b_i, b_j)
    _, b_j = sch.split(b_ij, factors=[None, 16])  # outer: 8
    sch.bind(b_j, "threadIdx.x")
    # auto unroll: annotate the outermost (blockIdx) loop
    sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024))
    sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1))
    return sch
def tir_imm(obj, dtype=None) -> tir.PrimExpr:
    """Convert a Python scalar into a TIR immediate expression.

    PrimExpr inputs pass through unchanged. bool/float/int/str map to
    IntImm("bool")/FloatImm/IntImm/StringImm; `dtype`, when given,
    overrides the default dtype for the numeric cases.

    Raises:
        TypeError: if `obj` is of an unsupported type.
    """
    if isinstance(obj, tir.PrimExpr):
        return obj
    # bool must be checked before int: bool is a subclass of int.
    if isinstance(obj, bool):
        return tir.IntImm(dtype=dtype or 'bool', value=obj)
    if isinstance(obj, float):
        return tir.FloatImm(dtype=dtype or 'float32', value=obj)
    if isinstance(obj, int):
        return tir.IntImm(dtype=dtype or 'int32', value=obj)
    if isinstance(obj, str):
        return tir.StringImm(obj)
    # Was `assert False`: asserts are stripped under `python -O` and carry no
    # message; raise an explicit, descriptive error instead.
    raise TypeError(f"cannot convert object of type {type(obj).__name__!r} "
                    "to a TIR immediate")
def test_eq_ops():
    """Check `==`/`!=` behavior of TIR immediates when compared against None.

    The `== None` / `!= None` forms (normally discouraged) are the very
    operations under test here, so they are intentional.
    """
    a = tir.IntImm("int8", 1)
    # Comparing an IntImm with None is expected to raise ValueError.
    with pytest.raises(ValueError):
        assert a != None
    with pytest.raises(ValueError):
        assert not a == None
    # StringImm, by contrast, supports None comparison and is never equal to it.
    b = tir.StringImm("abc")
    assert b != None
    assert not b == None
def parse_for(self, node, parent):
    """Parse a for-loop node into a normalized TIR `For` statement.

    The loop is rewritten to iterate from 0 to `extent`, where
    extent = floordiv(upper - lower, step); the original induction
    variable is recovered in the body via c_var = iter_var * step + lower.

    NOTE(review): floordiv of (upper - lower) assumes a positive step and
    an exclusive upper bound — confirm against `_for_loop_vars`.
    """
    with self._for_loop_vars(node) as (iter_var, c_var, extent_var, lower, upper,
                                       step, for_type):
        # Bind the computed trip count to extent_var so it is evaluated once.
        extent = tir.FloorDiv(tir.Sub(upper, lower), step)
        return tir.LetStmt(
            extent_var, extent,
            tir.For(
                iter_var, tir.IntImm('int32', 0), extent_var, for_type,
                # Reconstruct the user-visible loop variable inside the body.
                tir.LetStmt(c_var, tir.Add(tir.Mul(iter_var, step), lower),
                            self.parse(node.body(), node))))
def expected():
    """Build the expected graph: a "conv2d_bias_relu" composite feeding an
    "add_relu" composite, whose result is multiplied by a free input.

    Both composites are tagged Primitive=1 and Composite=<pattern name> so
    the partitioner treats each as a single fused unit.
    """
    data = relay.var('data', shape=(1, 512, 28, 28))
    kernel = relay.var('kernel', shape=(256, 512, 1, 1))
    bias = relay.var('bias', shape=(256, ))
    a = relay.var('a', shape=(1, 256, 28, 28))
    b = relay.var('b', shape=(1, 256, 28, 28))
    # conv_bias_relu function: conv2d -> bias_add -> relu
    in_1 = relay.var('in_1', shape=(1, 512, 28, 28))
    in_2 = relay.var('in_2', shape=(256, 512, 1, 1))
    in_3 = relay.var('in_3', shape=(256, ))
    conv_node = relay.nn.conv2d(in_1,
                                in_2,
                                kernel_size=(1, 1),
                                padding=(0, 0),
                                strides=(1, 1))
    bias_node = relay.nn.bias_add(conv_node, in_3)
    r = relay.nn.relu(bias_node)
    conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
    conv_bias_add_relu = conv_bias_add_relu.set_attribute(
        "Primitive", tir.IntImm("int32", 1))
    conv_bias_add_relu = conv_bias_add_relu.set_attribute(
        "Composite", tir.StringImm("conv2d_bias_relu"))
    # add_relu function: add -> relu
    in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
    in_5 = relay.var('in_5', shape=(1, 256, 28, 28))
    add_node = relay.add(in_4, in_5)
    r = relay.nn.relu(add_node)
    add_relu = relay.Function([in_4, in_5], r)
    add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
    add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
    # merged function: conv composite output feeds the add_relu composite
    conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
    add_relu_1 = relay.Call(add_relu, [conv_bias_add_relu_1, a])
    r = relay.multiply(add_relu_1, b)
    return relay.Function([data, kernel, bias, a, b], r)
def after():
    """Build the expected graph: two identical "add_sub_mul" composite calls
    over the same inputs, multiplied together.
    """
    in_a = relay.var('input_1', shape=(10, 10))
    in_b = relay.var('input_2', shape=(10, 10))

    def make_add_sub_mul(name_x, name_y):
        # One composite branch computing (x + y) * (x - y),
        # tagged Primitive=1 / Composite="add_sub_mul".
        px = relay.var(name_x)
        py = relay.var(name_y)
        body = relay.multiply(relay.add(px, py), relay.subtract(px, py))
        fn = relay.Function([px, py], body)
        fn = fn.set_attribute('Primitive', tir.IntImm('int32', 1))
        return fn.set_attribute('Composite', tir.StringImm("add_sub_mul"))

    call_a = relay.Call(make_add_sub_mul('x', 'y'), [in_a, in_b])
    call_b = relay.Call(make_add_sub_mul('x1', 'y1'), [in_a, in_b])
    return relay.Function([in_a, in_b], relay.multiply(call_a, call_b))
def after_A():
    """Build the expected graph: two "add_relu" composites (over inputs 0/1
    and 2/3) feeding a single "add_sub_mul" composite.
    """
    graph_inputs = [relay.var('input_' + str(i), shape=(10, 10)) for i in range(4)]

    def tag(fn, pattern_name):
        # Mark a function as a primitive composite with the given pattern name.
        fn = fn.set_attribute('Primitive', tir.IntImm('int32', 1))
        return fn.set_attribute('Composite', tir.StringImm(pattern_name))

    # First add_relu composite, applied to inputs 0 and 1.
    p0 = relay.var('x')
    p1 = relay.var('y')
    relu_fn_1 = tag(relay.Function([p0, p1], relay.nn.relu(relay.add(p0, p1))),
                    'add_relu')
    relu_call_1 = relay.Call(relu_fn_1, [graph_inputs[0], graph_inputs[1]])

    # Second add_relu composite, applied to inputs 2 and 3.
    p2 = relay.var('x1')
    p3 = relay.var('y1')
    relu_fn_2 = tag(relay.Function([p2, p3], relay.nn.relu(relay.add(p2, p3))),
                    'add_relu')
    relu_call_2 = relay.Call(relu_fn_2, [graph_inputs[2], graph_inputs[3]])

    # add_sub_mul composite combining the two relu results: (x+y)*(x-y).
    p4 = relay.var('x2')
    p5 = relay.var('y2')
    asm_body = relay.multiply(relay.add(p4, p5), relay.subtract(p4, p5))
    asm_fn = tag(relay.Function([p4, p5], asm_body), 'add_sub_mul')
    final_call = relay.Call(asm_fn, [relu_call_1, relu_call_2])
    return relay.Function(graph_inputs, final_call)
def after_A_priority(composite_name):
    """Build the expected graph: a single composite relu(abs(x + y)) whose
    Composite attribute is `composite_name`.
    """
    lhs = relay.var('input_1', shape=(10, 10))
    rhs = relay.var('input_2', shape=(10, 10))
    # Composite body: add -> abs -> relu.
    px = relay.var('x')
    py = relay.var('y')
    body = relay.nn.relu(relay.abs(relay.add(px, py)))
    fused = relay.Function([px, py], body)
    fused = fused.set_attribute('Primitive', tir.IntImm('int32', 1))
    fused = fused.set_attribute('Composite', tir.StringImm(composite_name))
    return relay.Function([lhs, rhs], relay.Call(fused, [lhs, rhs]))
def expected():
    """Build the expected graph: one "add_relu" composite applied to the
    two graph inputs.
    """
    lhs = relay.var('a', shape=(10, 10))
    rhs = relay.var('b', shape=(10, 10))
    # add_relu composite: add -> relu, tagged Primitive/Composite.
    p1 = relay.var('in_1', shape=(10, 10))
    p2 = relay.var('in_2', shape=(10, 10))
    body = relay.nn.relu(relay.add(p1, p2))
    fused = relay.Function([p1, p2], body)
    fused = fused.set_attribute("Primitive", tir.IntImm("int32", 1))
    fused = fused.set_attribute("Composite", tir.StringImm("add_relu"))
    # merged function is just a single call to the composite
    return relay.Function([lhs, rhs], relay.Call(fused, [lhs, rhs]))
def expected():
    """Build the expected graph: a subtract node feeding an "add_add_add"
    composite function.
    """
    lhs = relay.var('a', shape=(10, 10))
    rhs = relay.var('b', shape=(10, 10))
    # add_add_add composite: t = in_1 + in_2; u = in_1 + t; result = u + t.
    p1 = relay.var('in_1', shape=(10, 10))
    p2 = relay.var('in_2', shape=(10, 10))
    t = relay.add(p1, p2)
    u = relay.add(p1, t)
    fused = relay.Function([p1, p2], relay.add(u, t))
    fused = fused.set_attribute("Primitive", tir.IntImm("int32", 1))
    fused = fused.set_attribute("Composite", tir.StringImm("add_add_add"))
    # merged function: (a - b) is the composite's first argument
    diff = relay.subtract(lhs, rhs)
    return relay.Function([lhs, rhs], relay.Call(fused, [diff, rhs]))
def after_B():
    """Build the expected graph: four "add_relu" composites over input pairs,
    combined as multiply(add(c0, c1), subtract(c2, c3)).
    """
    inputs = [
        relay.var('input_' + str(i), shape=(10, 10)) for i in range(8)
    ]
    add_relu_calls = []
    for i in range(4):
        x = relay.var('x' + str(i))
        # BUGFIX: this var was also named 'x' + str(i) (copy-paste), giving
        # both composite parameters the same textual name. Relay vars are
        # identified by object, not name, so the graph structure (and
        # structural equality) is unchanged; the printed form is now sane.
        y = relay.var('y' + str(i))
        add_relu = relay.add(x, y)
        add_relu = relay.nn.relu(add_relu)
        add_relu = relay.Function([x, y], add_relu)
        add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
        add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
        add_relu_call = relay.Call(add_relu, [inputs[i * 2], inputs[i * 2 + 1]])
        add_relu_calls.append(add_relu_call)
    add = relay.add(add_relu_calls[0], add_relu_calls[1])
    sub = relay.subtract(add_relu_calls[2], add_relu_calls[3])
    out = relay.multiply(add, sub)
    return relay.Function(inputs, out)
def parse_op_minus(self, expr, parent):
    """Parse a unary or binary minus expression into a TIR subtraction."""
    first = self.parse(expr.arg(0), expr)
    if expr.n_arg() == 1:
        # Unary negation: represent -x as (0 - x).
        return tir.Sub(tir.IntImm('int32', 0), first)
    return tir.Sub(first, self.parse(expr.arg(1), expr))
def parse_int(self, expr, parent):
    """Parse an integer literal node into a 32-bit TIR immediate."""
    literal_text = expr.to_C_str()
    return tir.IntImm('int32', int(literal_text))
def apply(val=None):
    """Return `val` if one was supplied, else a fresh random boolean literal.

    The generated value is recorded in `BoolLit.last_random` before being
    wrapped as a TIR "bool" immediate.
    """
    # BUGFIX: was `if val:` — an explicitly supplied `False` was falsy and
    # silently discarded, producing a random literal instead. Test identity
    # against None so every provided value is honored.
    if val is not None:
        return val
    BoolLit.last_random = random.random() >= 0.5
    return tir.IntImm('bool', BoolLit.last_random)
def apply(val=None):
    """Return `val` if one was supplied, else a fresh random int literal.

    The generated value is recorded in `IntLit.last_random` before being
    wrapped as a TIR "int32" immediate.
    """
    # BUGFIX: was `if val:` — an explicitly supplied 0 was falsy and silently
    # discarded, producing a random literal instead. Test identity against
    # None so every provided value is honored.
    if val is not None:
        return val
    IntLit.last_random = random.randint(-100, 100)
    return tir.IntImm('int32', IntLit.last_random)