def validate(shape, value, dtype):
    def before_left(x, elem_op, full):
        return elem_op(full, x)

    def after_left(x, elem_op, value):
        return elem_op(relay.const(value, dtype), x)

    def before_right(x, elem_op, full):
        return elem_op(x, full)

    def after_right(x, elem_op, value):
        return elem_op(x, relay.const(value, dtype))

    x = relay.var("x", shape=shape, dtype=dtype)
    elem_ops = [relay.add, relay.multiply, relay.subtract, relay.divide]
    full_ops = []
    if value == 0:
        full_ops.append(relay.zeros(shape, dtype))
        full_ops.append(relay.zeros_like(x))
    if value == 1:
        full_ops.append(relay.ones(shape, dtype))
        full_ops.append(relay.ones_like(x))
    else:
        full_ops.append(relay.full(relay.const(value, dtype), shape))
        full_ops.append(relay.full_like(x, relay.const(value, dtype)))

    for op in elem_ops:
        for full in full_ops:
            z = before_left(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(after_left(x, op, value), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

            z = before_right(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(after_right(x, op, value), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

    # Test the case in which x is broadcast to full's shape
    full_ops = []
    if value == 0:
        full_ops.append(relay.zeros(shape * 2, dtype))
    if value == 1:
        full_ops.append(relay.ones(shape * 2, dtype))
    else:
        full_ops.append(relay.full(relay.const(value, dtype), shape * 2))

    for op in elem_ops:
        for full in full_ops:
            z = before_left(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(before_left(x, op, full), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

            z = before_right(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(before_right(x, op, full), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)
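# A hedged driver sketch for validate() above: it exercises the full/full_like
# simplification over a few shape/value/dtype combinations. The test name and the
# particular combinations below are illustrative assumptions, not taken verbatim
# from the original suite.
def test_simplify_full_elementwise_sketch():
    for shape in [[10], [10, 10], [10, 10, 10]]:
        for value in [0, 1, 2]:
            for dtype in ["int32", "float32"]:
                validate(shape, value, dtype)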
def test_loop_free_var():
    x = relay.var("x", shape=(), dtype="int32")
    i = relay.var("i", shape=(), dtype="int32")
    s = relay.var("s", shape=(), dtype="int32")

    def cond(i, _):
        return i < relay.const(10, dtype="int32")

    def body_no_free_var(i, acc):
        incr = relay.const(1, "int32")
        return i + incr, acc + i

    def body_with_free_var(i, acc):
        incr = relay.const(1, "int32")
        return i + incr, acc + x

    for args, body, expected in zip(
        [[], [1]], [body_no_free_var, body_with_free_var], [45, 10]
    ):
        loop = while_loop(cond, [i, s], body)
        tup = loop(relay.const(0, dtype="int32"), relay.zeros(shape=(), dtype="int32"))
        ret = relay.TupleGetItem(tup, 1)
        mod = tvm.IRModule()
        mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
        check_result(args, expected, mod=mod)
def after():
    data = relay.var("data", shape=(1, 32))
    eq1 = relay.var("e1", shape=[], dtype="float32")
    eq2 = relay.var("e2", shape=[], dtype="float32")

    cb_1 = relay.annotation.compiler_begin(eq1, target)
    cb_2 = relay.annotation.compiler_begin(eq2, target)
    equality_condition = relay.equal(cb_1, cb_2)
    ce_1 = relay.annotation.compiler_end(equality_condition, target)

    # if condition
    true_branch = relay.zeros(shape=(1, 32), dtype="float32")

    # else condition
    cb_3 = relay.annotation.compiler_begin(data, target)
    false_branch = relay.sigmoid(cb_3)
    ce_2 = relay.annotation.compiler_end(false_branch, target)

    if_condition = relay.If(ce_1, true_branch, ce_2)
    cb_4 = relay.annotation.compiler_begin(if_condition, target)
    erf_out = relay.erf(cb_4)
    ce_3 = relay.annotation.compiler_end(erf_out, target)
    func = relay.Function([data, eq1, eq2], ce_3)
    mod = tvm.IRModule.from_expr(func)
    return mod
def test_recursive():
    mod = tvm.IRModule()

    x = relay.var("x", shape=(2,))
    i = relay.var("i", shape=(), dtype="int32")
    s = relay.var("s", shape=(2,))
    cond = i < relay.const(10, dtype="int32")

    loop = relay.var("while_loop")
    sb = relay.scope_builder.ScopeBuilder()
    with sb.if_scope(cond):
        ii = i + relay.const(1, dtype="int32")
        ss = s + x
        sb.ret(loop(ii, ss))
    with sb.else_scope():
        sb.ret(s)
    func = relay.Function([i, s], sb.get())

    ret = relay.Let(
        loop,
        func,
        loop(relay.const(0, dtype="int32"), relay.zeros(shape=(2,), dtype="float32")),
    )
    mod["main"] = relay.Function([x], ret)
    new_mod = transform.LambdaLift()(mod)
    assert len(new_mod.functions) == 2
def test_concretize_zeros_like():
    dtype = "int32"
    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
    expr = relay.zeros_like(shape_like)

    expected = run_infer_type(relay.zeros((3, 4, 5), dtype))
    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)
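# A hedged companion sketch mirroring test_concretize_zeros_like above: SimplifyExpr
# is expected to concretize ones_like of a statically shaped input into a static ones.
# Assumes the same helpers (run_infer_type, run_opt_pass) are in scope.
def test_concretize_ones_like():
    dtype = "int32"
    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
    expr = relay.ones_like(shape_like)

    expected = run_infer_type(relay.ones((3, 4, 5), dtype))
    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)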
def test_closure():
    mod = tvm.IRModule()

    x = relay.var("x", shape=(2,))
    y = relay.var("y", shape=(2,))
    inner_func = relay.Function([x], x + y)
    outer_func = relay.Function([y], inner_func)
    clo = outer_func(relay.ones(shape=(2,), dtype="float32"))
    mod["main"] = relay.Function([], relay.Call(clo, [relay.zeros(shape=(2,), dtype="float32")]))

    new_mod = transform.LambdaLift()(mod)
    assert len(new_mod.functions) == 3
def get_func_with_control_flow():
    data = relay.var("data", shape=(1, 3, 224, 224))
    weight = relay.var("weight", shape=(32, 3, 3, 3))
    eq1 = relay.var("e1", shape=[], dtype="float32")
    eq2 = relay.var("e2", shape=[], dtype="float32")
    eq = relay.equal(eq1, eq2)

    true_branch = relay.zeros(shape=(1, 32, 222, 222), dtype="float32")
    false_branch = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=32)
    ife = relay.If(eq, true_branch, false_branch)
    out = relay.erf(ife)
    return relay.Function([data, weight, eq1, eq2], out)
def before():
    data = relay.var("data", shape=(1, 32))
    eq1 = relay.var("e1", shape=[], dtype="float32")
    eq2 = relay.var("e2", shape=[], dtype="float32")
    eq = relay.equal(eq1, eq2)

    true_branch = relay.zeros(shape=(1, 32), dtype="float32")
    false_branch = relay.sigmoid(data)
    ife = relay.If(eq, true_branch, false_branch)
    out = relay.erf(ife)
    func = relay.Function([data, eq1, eq2], out)
    mod = tvm.IRModule.from_expr(func)
    return mod
def get_net(iterations, num_hidden, batch_size=1, dtype="float32"):
    """Constructs an unrolled RNN with LSTM cells"""
    input_type = relay.TensorType((batch_size, num_hidden), dtype)
    weight_type = relay.TensorType((4 * num_hidden, num_hidden), dtype)
    bias_type = relay.TensorType((4 * num_hidden,), dtype)

    state_type = relay.TupleType([input_type, input_type])
    cell_type = relay.TupleType([input_type, state_type])

    builder = relay.ScopeBuilder()

    zeros = builder.let(("zeros", input_type), relay.zeros((batch_size, num_hidden), dtype))
    init_states = builder.let(("init_states", state_type), relay.Tuple([zeros, zeros]))

    states = init_states
    out = None

    for i in range(iterations):
        inputs = relay.Var("data", input_type)
        i2h_weight = relay.Var("i2h_%s_weight" % i, weight_type)
        i2h_bias = relay.Var("i2h_%s_bias" % i, bias_type)
        h2h_weight = relay.Var("h2h_%s_weight" % i, weight_type)
        h2h_bias = relay.Var("h2h_%s_bias" % i, bias_type)

        cell_fn = lstm_cell(num_hidden, batch_size, dtype, "lstm_%s" % i)

        call = builder.let(
            ("call_%s" % i, cell_type),
            relay.Call(cell_fn, [inputs, states, i2h_weight, i2h_bias, h2h_weight, h2h_bias]),
        )
        new_out = builder.let(("out_%s" % i, input_type), relay.TupleGetItem(call, 0))
        new_states = builder.let(("states_%s" % i, state_type), relay.TupleGetItem(call, 1))
        states = new_states
        out = new_out

    builder.ret(out)
    body = builder.get()
    args = relay.analysis.free_vars(body)
    return relay.Function(args, body, input_type)
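# A minimal usage sketch (an assumption, not part of the original file): instantiate a
# small unrolled LSTM from get_net above and wrap it in an IRModule so it can be type
# checked. The wrapper name get_workload_sketch and the parameter choices are hypothetical;
# lstm_cell is assumed to be importable from the same testing module as get_net.
def get_workload_sketch(iterations=2, num_hidden=16, batch_size=1, dtype="float32"):
    net = get_net(iterations, num_hidden, batch_size, dtype)
    mod = tvm.IRModule.from_expr(net)
    # Running type inference verifies that the unrolled cells compose into a well-typed graph.
    return relay.transform.InferType()(mod)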
def test_zeros():
    """Simple test using "zeros" op"""
    mod = tvm.IRModule()

    shape = (10, 10)
    dtype = "float32"
    t = relay.TensorType(shape, dtype)

    x = relay.var("x", t)
    y = relay.Function([x], x + relay.zeros(shape, dtype))

    mod["main"] = y
    mod = transform.InferType()(mod)
    mod = transform.LazyGradientInit()(mod)
    y = mod["main"]

    assert mod["main"].checked_type == relay.FuncType([t], t)

    x = rand(dtype, *shape)
    y = create_executor(mod=mod).evaluate(y)(x)
    assert_allclose(y.numpy(), x.numpy())
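# A hedged companion sketch to test_zeros above, exercising the "ones" op through the same
# LazyGradientInit pipeline. It assumes the same helpers (rand, create_executor,
# assert_allclose) are in scope; numpy is imported locally to compute the expected result.
def test_ones_sketch():
    import numpy as np

    mod = tvm.IRModule()

    shape = (10, 10)
    dtype = "float32"
    t = relay.TensorType(shape, dtype)

    x = relay.var("x", t)
    y = relay.Function([x], x + relay.ones(shape, dtype))

    mod["main"] = y
    mod = transform.InferType()(mod)
    mod = transform.LazyGradientInit()(mod)
    y = mod["main"]

    assert mod["main"].checked_type == relay.FuncType([t], t)

    x = rand(dtype, *shape)
    y = create_executor(mod=mod).evaluate(y)(x)
    # Adding a tensor of ones should shift every element of the input by one.
    assert_allclose(y.numpy(), x.numpy() + np.ones(shape, dtype))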
def after():
    func = relay.Function([], relay.zeros(shape=(0), dtype="float32"))
    mod = tvm.IRModule.from_expr(func)
    return mod
def flexible_dispatch(
    mod, buckets, axis=0, auto_pad=False, pad_value=0, input_indices=None, affects_output=True
):
    """
    Enable inference of multiple shaped inputs in one module.

    This transformation adds a handler around a module that checks input shapes
    and dispatches to a subgraph specialized to handle the specific shapes of
    that input. If no exactly matching subgraph is available, the input will be
    run using full dynamism. For best performance, specify all the sizes the
    module will be likely to see using the buckets argument.

    By default, this function will dispatch shapes that exactly match one of
    the buckets to a corresponding subgraph. All non-matching shapes use the
    same fully dynamic fallback. This can be detrimental to performance for
    those non-matching shapes. Setting auto_pad to True causes this function to
    round up the shape of non-matching inputs to the closest bucket. This
    allows them to use the tuned kernels of bucket shapes, which can improve
    performance.

    Functions that have multiple inputs sharing a dynamic axis, which is common
    for batch size or sequence length dynamism, are supported through the
    input_indices argument.

    Many types of dynamism such as batching affect both the input and output
    shape; however, this is not always the case. If the output shape is
    independent of the input, the affects_output argument of this function must
    be set to False.

    Parameters
    ----------
    mod: IRModule
        The module whose "main" function should be wrapped with the flexible
        shape dispatch handler.

    buckets: list[int]
        The sizes of the input dimension that should be explicitly handled.
        Each value in buckets will have a corresponding subgraph constructed to
        handle it.

    axis: int
        The dimension of the input that should be made flexible. This will
        most often be used for the batch dimension.

    auto_pad: Optional[bool]
        If True, then padding will be inserted to values that don't match one of
        the provided buckets.

    pad_value: Optional[float]
        When auto_pad is true, padding will be done with this value.

    input_indices: Optional[List[int]]
        Which inputs should be dispatched dynamically, provided by index. All inputs
        must share the same dynamic axis.

    affects_output: Optional[bool]
        Whether the change in input shape has a corresponding effect on the output shape.
        Batching, for example, affects both the input and output, whereas changing
        sequence length in an NLP model typically does not.

    Returns
    -------
    mod : IRModule
        The new module wrapped with a flexible shape dispatch handler.
    """
    main_fn = mod["main"]

    # Default to single input if not specified.
    if input_indices is None:
        input_indices = [0]

    # Extract all input data and create a new dynamic variable for each.
    data = []
    dyn_data = []
    for i in input_indices:
        data.append(main_fn.params[i])
        dyn_shape = override_shape(data[i].type_annotation, axis, relay.Any())
        dyn_data.append(relay.Var(data[i].name_hint, type_annotation=dyn_shape))

    # Extract the dynamic shape value from one of the inputs.
    rt_sh = relay.op.shape_of(dyn_data[0])
    flex_value = relay.op.take(rt_sh, relay.const(axis))

    if_exprs = []

    for i, bucket in enumerate(buckets):
        input_data = dyn_data
        check_dim = flex_value

        # Apply automatic padding if specified.
        if auto_pad:
            input_data = []
            # Construct padding expression for inputs.
            for j, inp in enumerate(dyn_data):
                pad_width = relay.const(bucket) - flex_value
                rank = len(data[j].type_annotation.shape)
                pads = relay.zeros([rank, 2], "int32")
                pads = relay.scatter_nd(pads, relay.const([axis, 1]), pad_width)
                padded_value = relay.nn.pad(inp, pads, pad_value)

                # Determine if this is the proper bucket to pad to. Do this by checking if the
                # input shape is between this bucket and the previous one.
                if i == 0:
                    padded_value = relay.If(
                        relay.op.less_equal(flex_value, relay.const(bucket)), padded_value, inp
                    )
                else:
                    padded_value = relay.If(
                        relay.op.logical_and(
                            relay.op.less_equal(flex_value, relay.const(bucket)),
                            relay.op.greater(flex_value, relay.const(buckets[i - 1])),
                        ),
                        padded_value,
                        inp,
                    )
                # Update input value and test dimension to reflect possible padding.
                input_data.append(padded_value)
            # Grab the new, possibly padded shape for checking bucket size.
            check_dim = relay.op.take(relay.op.shape_of(input_data[0]), relay.const(axis))

        # Create a specialized subgraph for the current bucket.
        spec_call, spec_ty = specialize_body(
            mod, main_fn, axis, bucket, input_indices=input_indices, affects_output=affects_output
        )

        # Apply hard casting to the bucket shape to create statically typed graphs.
        spec_data = []
        for j, inp in enumerate(input_data):
            spec_data.append(relay.op.reshape(inp, spec_ty[j].shape))

        # Create a dispatch statement for the current specialized graph.
        call_args = list(main_fn.params)
        for j, inp in enumerate(input_indices):
            call_args[inp] = spec_data[j]
        new_call = spec_call(*call_args)

        # Remove meaningless padded outputs if applicable.
        if auto_pad and affects_output:
            new_call = relay.take(
                new_call,
                relay.arange(start=relay.const(0), stop=flex_value, dtype="int32"),
                axis=axis,
            )

        # Add this new case to the dispatch handler.
        if_exprs.append((relay.op.equal(check_dim, relay.const(bucket)), new_call))

    # Create a subgraph to handle all other shapes.
    default_dyn_call, _ = specialize_body(
        mod, main_fn, axis, relay.Any(), input_indices=input_indices, affects_output=affects_output
    )
    call_args = list(main_fn.params)
    for j, inp in enumerate(input_indices):
        call_args[inp] = dyn_data[j]
    new_body = default_dyn_call(*call_args)

    # Create an If chain to dispatch shapes to the appropriate specialized subgraph.
    for cond, true_branch in if_exprs:
        new_body = relay.If(cond, true_branch, new_body)

    # Assign new parameters to the function.
    new_params = list(main_fn.params)
    for j, inp in enumerate(input_indices):
        new_params[inp] = dyn_data[j]

    # Update the output shape to be dynamic if needed.
    if affects_output:
        dyn_ret_type = override_shape(main_fn.ret_type, axis, relay.Any())
    else:
        dyn_ret_type = main_fn.ret_type

    # Assign the handler as the new entrypoint in the module.
    new_main = relay.Function(
        new_params, new_body, dyn_ret_type, main_fn.type_params, main_fn.attrs
    )
    mod["main"] = new_main

    # Do type inference to make sure everything worked.
    mod = relay.transform.InferType()(mod)
    return mod
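# A hedged usage sketch (not from the original source): wrap a simple dense model so that
# batch sizes 1, 4, and 8 each get a specialized subgraph, padding any other batch size up
# to the nearest bucket. The helper name, the model, and the bucket choices are illustrative
# assumptions; flexible_dispatch, override_shape, and specialize_body are assumed to live in
# the same module as the function above.
def _example_flexible_dispatch():
    data = relay.var("data", shape=(1, 128), dtype="float32")
    weight = relay.var("weight", shape=(64, 128), dtype="float32")
    out = relay.nn.dense(data, weight)
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
    mod = relay.transform.InferType()(mod)

    # Make the batch axis (axis 0 of input index 0) flexible; since batching changes the
    # output shape of dense, affects_output is left at its default of True.
    return flexible_dispatch(mod, buckets=[1, 4, 8], axis=0, auto_pad=True, pad_value=0)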