def emit_shape_func(self, scope, func, new_args):
    """Emit the let-bindings that invoke the shape function of ``func``.

    The shape function is lowered for ``self.target_host``; its inputs are
    bound in ``scope``, one CPU allocation is made per output, and the
    call itself is bound as ``shape_func``.

    Parameters
    ----------
    scope
        Builder whose ``let`` method is used to bind every intermediate.
    func
        The primitive function whose shape function is being inserted.
    new_args
        Arguments of the call, paired positionally with the parameter
        states reported by the lowered shape function.

    Returns
    -------
    list
        The let-bound output allocations, one per shape function output.

    Raises
    ------
    Exception
        If a parameter state is neither 1 (pass input) nor 2 (pass shape).
    """
    lowered = compile_engine.get().lower_shape_func(func, self.target_host)
    param_states = lowered.shape_func_param_states

    bound_inputs = []
    is_inputs = []
    input_pos = 0
    cpu_ctx = nd.cpu(0)

    for arg, raw_state in zip(new_args, param_states):
        mode = int(raw_state)

        if mode not in (1, 2):
            # TODO(@jroesch): handle 3rd case
            raise Exception("unsupported shape function input state")

        if mode == 2:
            # Pass shapes: bind the shape of every field of the argument.
            fields = from_tuple_type(arg.type_annotation, arg)
            for j, field in enumerate(fields):
                field_shape = self.visit(self.shape_of(field))
                # input_pos is advanced inside this loop while j is added
                # as well, so the name suffix grows by two per field.
                # Only the binding name is affected.
                binding_name = "in_shape_{0}".format(input_pos + j)
                bound_inputs.append(scope.let(binding_name, field_shape))
                input_pos += 1
            is_inputs.append(0)
        else:
            # Pass inputs: the value itself is needed, and it has to be
            # on the CPU, so copy it over when it lives on another device.
            value = self.visit(arg)
            arg_ctx = self.get_context(arg)
            if arg_ctx.device_type != cpu_ctx.device_type:
                value = self.device_copy(value, arg_ctx, cpu_ctx)
            binding_name = "in_shape_{0}".format(input_pos)
            bound_inputs.append(scope.let(binding_name, value))
            input_pos += 1
            is_inputs.append(1)

    out_shapes = []
    for index, output in enumerate(lowered.outputs):
        out_type = ty.TensorType(output.shape, output.dtype)
        # Put shape func on CPU. This also ensures that everything between
        # shape_of and shape_func are on CPU.
        storage = self.make_static_allocation(scope, out_type, cpu_ctx, index)
        out_shapes.append(
            scope.let("shape_func_out_{0}".format(index), storage))

    ins_tuple = expr.Tuple(bound_inputs)
    outs_tuple = expr.Tuple(out_shapes)
    scope.let(
        "shape_func",
        self.shape_func(func, ins_tuple, outs_tuple, is_inputs))
    return out_shapes
def test_tuple_object():
    """Check that a VM-compiled function accepts an ADT tuple object.

    The function takes a tuple of two scalar int32 tensors and returns
    field 0; it is evaluated on a tuple object holding 11 and 12.
    """
    field_types = [relay.ty.TensorType((), 'int32') for _ in range(2)]
    x = relay.var('x', type_annotation=relay.ty.TupleType(field_types))
    first_field = relay.expr.TupleGetItem(x, 0)
    mod = tvm.IRModule.from_expr(relay.Function([x], first_field))

    executor = relay.create_executor(
        kind="vm", mod=mod, ctx=nd.cpu(), target="llvm")
    compiled = executor.evaluate()

    # Pass an ADT object (not a Python tuple) to the evaluated function.
    packed = _container.tuple_object(
        [nd.array(np.array(scalar)) for scalar in (11, 12)])
    result = compiled(packed)

    tvm.testing.assert_allclose(result.asnumpy(), np.array(11))