def test_dynamic_concat():
    """Type-check a while loop whose state tensor grows on every iteration.

    Relay program built below (approximately)::

        fn (%start: int32) {
          let %while_loop = fn (%i: int32, %st: Tensor[(?, 1), int32]) {
            if (min(less(%i, 10))) {
              %i_vec = reshape(%i, newshape=[1, 1])
              %while_loop(add(%i, 1), concatenate((%st, %i_vec), axis=0))
            } else {
              (%i, %st)
            }
          };
          %while_loop(%start, reshape(0, newshape=[1, 1])).1
        }

    ``st`` is declared with a dynamic first dimension (``relay.Any()``), so
    appending a (1, 1) row to it on every iteration is expected to
    type-check. The test has no explicit assertion: it passes as long as
    ``infer_type`` does not raise.
    """
    # Initial Values.
    i = relay.var('i', shape=(), dtype='int32')
    # Dynamic first dimension: the loop state may grow by one row per iteration.
    st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')

    def _cond(i, st):
        # Reduce the comparison with `min` to get the loop condition.
        return relay.op.min(relay.op.less(i, int32(10)))

    def _body(i, st):
        # Turn the scalar counter into a (1, 1) row and stack it onto the state.
        i_vec = relay.op.reshape(i, (1, 1))
        ret = relay.op.concatenate([st, i_vec], axis=0)
        return i + int32(1), ret

    loop = while_loop(_cond, [i, st], _body)
    start = relay.var('start', shape=(), dtype='int32')
    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
    # The loop yields (i, st); only the accumulated state (field 1) is returned.
    func = relay.Function([start], relay.TupleGetItem(body, 1))
    func = infer_type(func)
def test_recursive_concat_with_wrong_annotation():
    """Type inference must reject a loop state annotated with a fixed shape.

    Relay program built below::

        v0.0.1
        fn (%start: int32) {
          %7 = {
            let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
              %0 = less(%i, 10)
              %1 = min(%0)
              if (%1) {
                %2 = add(%i, 1)
                %3 = reshape(%i, newshape=[1, 1])
                %4 = (%st, %3)
                /* The result of concat should be 1,1 but it is 2, 1. */
                %5 = concatenate(%4)
                %while_loop(%2, %5)
              } else {
                (%i, %st)
              }
            }
            %6 = reshape(0, newshape=[1, 1])
            %while_loop(%start, %6)
          }
          %7.1
        }

    ``st`` is annotated as (1, 1), but the body passes a (2, 1) tensor back
    into the recursive call. ``infer_type`` is therefore expected to raise an
    error that mentions the dimension-0 mismatch.
    """
    # Initial Values.
    i = relay.var('i', shape=(), dtype='int32')
    # Deliberately wrong: a fixed (1, 1) shape cannot hold the growing state.
    st = relay.var('st', shape=(1, 1), dtype='int32')

    def _cond(i, st):
        return relay.op.min(relay.op.less(i, int32(10)))

    def _body(i, st):
        i_vec = relay.op.reshape(i, (1, 1))
        ret = relay.op.concatenate([st, i_vec], axis=0)
        return i + int32(1), ret

    loop = while_loop(_cond, [i, st], _body)
    start = relay.var('start', shape=(), dtype='int32')
    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
    func = relay.Function([start], relay.TupleGetItem(body, 1))
    # Keep only the call expected to fail inside `try`. An "expected failure"
    # check placed inside `try` would be swallowed by `except Exception` and
    # show up as a confusing substring mismatch.
    try:
        infer_type(func)
    except Exception as e:
        assert "in particular dimension 0 conflicts 2 does not match 1" in str(
            e)
    else:
        # `raise` is not stripped under `python -O`, unlike `assert False`.
        raise AssertionError(
            "infer_type should have rejected the (1, 1) annotation on st")