def expected():
        mod = tvm.IRModule({})
        x1 = relay.var("x1", shape=(3, 5))
        y1 = relay.var("y1", shape=(3, 5))
        sb = relay.ScopeBuilder()
        sb.ret(x1 + y1)
        fn1 = relay.Function([x1, y1], sb.get())
        g1 = relay.GlobalVar("g1")
        mod[g1] = fn1

        p0 = relay.var("p0", shape=(3, 5))
        p1 = relay.var("p1", shape=(3, 5))
        p2 = relay.var("p2", shape=(3, 5))
        p3 = relay.var("p3", shape=(3, 5))

        call_fn2 = p2 - p3
        mod["main"] = relay.Function([p0, p1, p2, p3], g1(p0, p1) * call_fn2)

        x0 = relay.var("x0", shape=(3, 5))
        y0 = relay.var("y0", shape=(3, 5))
        z0 = relay.var("z0", shape=(3, 5))

        fn0 = relay.Function([x0, y0, z0], x0 - y0 + z0)
        g0 = relay.GlobalVar("g0")
        mod[g0] = fn0

        return mod
Example 2
def test_custom_op_rel_infer_exception():
    """Tests infer type for custom_op"""

    def custom_log1_rel(arg_types, attrs):
        assert len(arg_types) == 2, "type relation arg number mismatch!"
        return None

    op_name = "custom_log2"
    _op.register(op_name, r"code(cal log of a tensor.)code")
    _op.get(op_name).set_num_inputs(1)
    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
    _op.get(op_name).set_attrs_type_key("DictAttrs")
    # call customized relation functions
    _op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
    _op.get(op_name).set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    def clog(x):
        return relay.Call(_op.get(op_name), [x])

    tp = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tp)
    sb = relay.ScopeBuilder()
    t1 = sb.let("t1", clog(x))
    t2 = sb.let("t2", relay.add(t1, x))
    sb.ret(t2)
    f = relay.Function([x], sb.get())
    with pytest.raises(tvm.error.TVMError) as cm:
        infer_expr(f)
    assert "type relation arg number mismatch" in str(cm.value)
Example 3
def test_recursion():
    """
    Program:
       def @f(%n: int32, %data: float32) -> float32 {
          if (%n == 0) {
              %data
          } else {
              @f(%n - 1, log(%data))
          }
       }
    """
    sb = relay.ScopeBuilder()
    f = relay.GlobalVar("f")
    ti32 = relay.scalar_type("int32")
    tf32 = relay.scalar_type("float32")
    n = relay.var("n", ti32)
    data = relay.var("data", tf32)

    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
        sb.ret(data)
    with sb.else_scope():
        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
    mod = tvm.IRModule()
    mod[f] = relay.Function([n, data], sb.get())
    assert "@f(%1, %2) /* ty=float32 */" in mod.astext()
    assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
 def expected():
     sb = relay.ScopeBuilder()
     x = relay.var("x", t)
     c_folded = (c_data + c_data)
     t3 = sb.let("t3", relay.add(relay.const(c_folded), x))
     sb.ret(t3)
     return relay.Function([x], sb.get())
    def get_mod():
        mod = tvm.IRModule({})

        x = relay.var('x', shape=[], dtype='int32')
        fn0 = relay.Function([x], x)
        fn0 = fn0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        gx = relay.GlobalVar("gx")
        mod[gx] = fn0

        sum_up = relay.GlobalVar('sum_up')
        i = relay.var('i', shape=[], dtype='int32')
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.equal(i, relay.const(0, dtype="int32"))):
            sb.ret(i)
        with sb.else_scope():
            one_less = relay.subtract(i, relay.const(1, dtype="int32"))
            global_call = gx(i)
            rec_call = relay.Call(sum_up, [one_less]) + global_call
            sb.ret(relay.add(rec_call, i))
        func = relay.Function([i],
                              sb.get(),
                              ret_type=relay.TensorType([], "int32"))
        func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
        mod[sum_up] = func
        iarg = relay.var("i", shape=[], dtype='int32')
        mod["main"] = relay.Function([iarg], sum_up(iarg))
        return mod
Example 6
def test_custom_op_rel_infer():
    """Tests infer type for custom_op"""

    def custom_log1_rel(arg_types, attrs):
        assert len(arg_types) == 1, "type relation arg number mismatch!"
        if attrs:
            assert isinstance(attrs, DictAttrs)
        inputa_type = arg_types[0]
        return relay.TensorType(inputa_type.shape, inputa_type.dtype)

    op_name = "custom_log1"
    _op.register(op_name, r"code(cal log of a tensor.)code")
    _op.get(op_name).set_num_inputs(1)
    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
    _op.get(op_name).set_attrs_type_key("DictAttrs")
    # call customized relation functions
    _op.get(op_name).add_type_rel("custom_log1", custom_log1_rel)
    _op.get(op_name).set_support_level(1)
    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
    _op.register_stateful(op_name, False)

    def clog(x):
        return relay.Call(_op.get(op_name), [x])

    tp = relay.TensorType((10, 10), "float32")
    x = relay.var("x", tp)
    sb = relay.ScopeBuilder()
    t1 = sb.let("t1", clog(x))
    t2 = sb.let("t2", relay.add(t1, x))
    sb.ret(t2)
    f = relay.Function([x], sb.get())
    fchecked = infer_expr(f)
    assert fchecked.checked_type == relay.FuncType([tp], tp)
Example 7
def test_recursion():
    """
    Program:
       def f(n: i32, data: f32) -> f32 {
          if (n == 0) {
              return data;
          } else {
              return f(n - 1, log(data));
          }
       }
    """
    sb = relay.ScopeBuilder()
    f = relay.GlobalVar("f")
    ti32 = relay.scalar_type("int32")
    tf32 = relay.scalar_type("float32")
    n = relay.var("n", ti32)
    data = relay.var("data", tf32)

    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
        sb.ret(data)
    with sb.else_scope():
        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
    mod = relay.Module()
    mod[f] = relay.Function([n, data], sb.get())
    assert "%3 = @f(%1, %2)" in mod.astext()
    assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)
Example 8
def test_monomorphic_let():
    "Program: let x = 1; x"
    sb = relay.ScopeBuilder()
    x = sb.let('x', relay.const(1.0, "float64"))
    sb.ret(x)
    xchecked = relay.ir_pass.infer_type(sb.get())
    assert xchecked.checked_type == relay.scalar_type("float64")
def test_recursive_func():
    mod = tvm.IRModule({})

    x = relay.var('x', shape=[], dtype='int32')
    fn0 = relay.Function([x], x)
    gx = relay.GlobalVar("gx")
    mod[gx] = fn0

    sum_up = relay.GlobalVar('sum_up')
    i = relay.var('i', shape=[], dtype='int32')
    sb = relay.ScopeBuilder()
    with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
        sb.ret(i)
    with sb.else_scope():
        one_less = relay.subtract(i, relay.const(1, dtype='int32'))
        global_call = gx(i)
        rec_call = relay.Call(sum_up, [one_less]) + global_call
        sb.ret(relay.add(rec_call, i))
    func = relay.Function([i],
                          sb.get(),
                          ret_type=relay.TensorType([], 'int32'))
    func = func.with_attr("Compiler", "a")
    mod[sum_up] = func
    iarg = relay.var('i', shape=[], dtype='int32')
    mod["main"] = relay.Function([iarg], sum_up(iarg))
    call_graph = relay.analysis.CallGraph(mod)

    assert call_graph.is_recursive(sum_up)
    assert call_graph.ref_count(sum_up) == 2
    assert call_graph.ref_count(gx) == 1
    assert call_graph.ref_count("main") == 0
Example 10
def test_checkpoint(executor_kind, target, dev):
    inputs = [relay.var("x{}".format(i), shape=(1, )) for i in range(4)]
    output = relay.multiply(relay.add(inputs[0], inputs[1]),
                            relay.add(inputs[2], inputs[3]))
    check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)),
               executor_kind=executor_kind)

    scope = relay.ScopeBuilder()
    out_tuple = scope.let(
        "out_tuple",
        relay.Tuple([
            relay.add(inputs[0], inputs[1]),
            relay.multiply(inputs[2], inputs[3])
        ]),
    )
    scope.ret(
        relay.subtract(
            relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)),
            relay.TupleGetItem(out_tuple, 1),
        ))
    out_single = scope.get()
    check_grad(
        relay.Function(inputs, out_single),
        target_devices=[(target, dev)],
        executor_kind=executor_kind,
    )
Example 11
 def __init__(self, params, quantized_dtypes):
     ExprMutator.__init__(self)
     self.params = set(params)
     self.quantized_dtypes = quantized_dtypes
     self.subtree_params = set()
     self.new_func_params = []
     self.prefix_sb = relay.ScopeBuilder()
     self.prefix_binding_map = {}
Example 12
 def before():
     sb = relay.ScopeBuilder()
     x = relay.var("x", t)
     t1 = sb.let("t1", relay.const(c_data))
     t2 = sb.let("t2", relay.add(t1, t1))
     t3 = sb.let("t3", relay.add(t2, x))
     sb.ret(t3)
     return relay.Function([x], sb.get())
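
The before() above leaves t and c_data free because it was lifted out of a larger constant-folding test. A minimal sketch of how such a builder is typically exercised, with hypothetical t and c_data supplied at module scope for illustration:

import numpy as np
import tvm
from tvm import relay

t = relay.TensorType((2, 3), "float32")       # hypothetical type annotation
c_data = np.ones((2, 3), dtype="float32")     # hypothetical constant payload

mod = tvm.IRModule.from_expr(before())
mod = relay.transform.InferType()(mod)
# FoldConstant collapses t1 = const and t2 = t1 + t1 into a single constant,
# leaving only the let that adds the folded constant to x.
mod = relay.transform.FoldConstant()(mod)
print(mod["main"])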
Example 13
 def before(dshape, dtype):
     x = relay.var('x', shape=dshape, dtype=dtype)
     y = relay.var('y', shape=dshape, dtype=dtype)
     sb = relay.ScopeBuilder()
     with sb.if_scope(relay.op.greater(x, y)):
         sb.ret(relay.Function([], x))
     with sb.else_scope():
         sb.ret(relay.Function([], y))
     return relay.Function([x, y], relay.Call(sb.get(), []))
Example 14
 def expected():
     sb = relay.ScopeBuilder()
     x = relay.var("x", t)
     c_folded = c_data + c_data
     t3 = sb.let(
         "t3", annot_expr(relay.add(annot_expr(relay.const(c_folded)), x)))
     sb.ret(t3)
     f = relay.Function([x], sb.get())
     return annot_func(f)
Example 15
def test_monomorphic_let():
    "Program: let %x = 1; %x"
    # TODO(@jroesch): this seems whack.
    sb = relay.ScopeBuilder()
    x = relay.var("x", dtype="float64", shape=())
    x = sb.let(x, relay.const(1.0, "float64"))
    sb.ret(x)
    xchecked = infer_expr(sb.get())
    assert xchecked.checked_type == relay.scalar_type("float64")
Example 16
 def func1():
     sb = relay.ScopeBuilder()
     p0 = relay.var("p0", shape=shape)
     p1 = relay.var("p1", shape=shape)
     a0 = sb.let("a0", relay.add(p0, relay.const(1)))
     a1 = sb.let("a1", relay.add(p1, relay.const(1)))
     a2 = sb.let("a2", relay.add(a0, a1))
     sb.ret(a2)
     return relay.Function([p0, p1], sb.get())
Example 17
 def func2():
     # Alpha conversion is structurally equal
     sb = relay.ScopeBuilder()
     p0 = relay.var("p0", shape=shape)
     p1 = relay.var("p1", shape=shape)
     a1 = sb.let("a1", relay.add(p0, relay.const(1)))
     a0 = sb.let("a0", relay.add(p1, relay.const(1)))
     a2 = sb.let("a2", relay.add(a1, a0))
     sb.ret(a2)
     return relay.Function([p0, p1], sb.get())
Example 18
 def func3():
     # But changing the order of bindings is not structurally equal
     # (even though algebraically equal)
     sb = relay.ScopeBuilder()
     p0 = relay.var("p0", shape=shape)
     p1 = relay.var("p1", shape=shape)
     a1 = sb.let("a1", relay.add(p1, relay.const(1)))
     a0 = sb.let("a0", relay.add(p0, relay.const(1)))
     a2 = sb.let("a2", relay.add(a1, a0))
     sb.ret(a2)
     return relay.Function([p0, p1], sb.get())
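
func1, func2, and func3 above differ only in the names and order of their let bindings; a minimal sketch of the comparison they set up, assuming shape is defined in the enclosing test:

# func1 and func2 are alpha-equivalent (same bindings, different names),
# so they compare structurally equal; func3 reorders the bindings and does not.
assert tvm.ir.structural_equal(func1(), func2())
assert not tvm.ir.structural_equal(func1(), func3())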
Example 19
 def before():
     sb = relay.ScopeBuilder()
     x = relay.var("x", t)
     x.virtual_device_ = tvm.cpu()
     t1 = sb.let("t1", annot_expr(relay.const(c_data)))
     t2 = sb.let("t2", annot_expr(relay.add(t1, t1)))
     t3 = sb.let("t3", annot_expr(relay.add(t2, x)))
     sb.ret(t3)
     f = relay.Function([x], sb.get())
     f.virtual_device_ = tvm.cpu()
     return f
Example 20
 def expected():
     sb = relay.ScopeBuilder()
     x = relay.var("x", t)
     x.virtual_device_ = tvm.cpu()
     c_folded = c_data + c_data
     t3 = sb.let(
         "t3", annot_expr(relay.add(annot_expr(relay.const(c_folded)), x)))
     sb.ret(t3)
     f = relay.Function([x], sb.get())
     f.virtual_device_ = tvm.cpu()
     return f
Example 21
def build_relay_module(batch_size, input_size, hidden_size, time_steps,
                       dense_dim):
    mod = tvm.IRModule()
    mod["lstm_layer"] = lstm_definition(batch_size, input_size, hidden_size,
                                        time_steps)
    mod["linear_layer"] = linear_layer_definition(batch_size, hidden_size,
                                                  dense_dim)
    lstm_var = mod.get_global_var("lstm_layer")
    linear_var = mod.get_global_var("linear_layer")

    # now we build up our main function
    input_var = relay.var("input", shape=(batch_size, time_steps, input_size))
    init_hidden_var = relay.var("init_hidden", shape=(batch_size, hidden_size))
    init_cell_var = relay.var("init_cell", shape=(batch_size, hidden_size))
    i2h_weight_var = relay.var("i2h_weight",
                               shape=(4 * hidden_size, input_size))
    h2h_weight_var = relay.var("h2h_weight",
                               shape=(4 * hidden_size, hidden_size))
    lstm_bias_var = relay.var("lstm_bias", shape=(4 * hidden_size, ))
    linear_weight_var = relay.var("linear_weight",
                                  shape=(dense_dim, hidden_size))
    linear_bias_var = relay.var("linear_bias", shape=(dense_dim, ))

    builder = relay.ScopeBuilder()
    state_var = builder.let("state",
                            relay.Tuple([init_hidden_var, init_cell_var]))
    lstm_res = builder.let(
        "lstm_res",
        lstm_var(
            input_var,
            state_var,
            i2h_weight_var,
            h2h_weight_var,
            lstm_bias_var,
            # the keras model only gave one bias,
            # so set the other to zero
            # (hopefully this is correct)
            relay.zeros_like(lstm_bias_var)))
    final_hidden = builder.let("final_hidden", relay.TupleGetItem(lstm_res, 1))
    # to match PT's semantics, we're undoing the reshape in LSTM :)
    reshape_hidden = builder.let("reshape_hidden",
                                 relay.squeeze(final_hidden, axis=[0]))
    linear_result = builder.let(
        "linear_result",
        linear_var(reshape_hidden, linear_weight_var, linear_bias_var))
    # finally do a softmax
    builder.ret(relay.nn.softmax(linear_result))
    main_func = relay.Function([
        input_var, init_hidden_var, init_cell_var, i2h_weight_var,
        h2h_weight_var, lstm_bias_var, linear_weight_var, linear_bias_var
    ], builder.get())
    mod["main"] = main_func
    return mod
Example 22
def test_decl():
    """Program:
       def f(x : Tensor[(10, 10), f32]) {
           return log(x);
       }
    """
    sb = relay.ScopeBuilder()
    tp = relay.TensorType((10, 10))
    x = relay.var("x", tp)
    f = relay.Function([x], relay.log(x))
    fchecked = relay.ir_pass.infer_type(f)
    assert fchecked.checked_type == relay.FuncType([tp], tp)
 def expected(dshape, dtype):
     x = relay.var('x', shape=dshape, dtype=dtype)
     y = relay.var('y', shape=dshape, dtype=dtype)
     sb = relay.ScopeBuilder()
     p1 = relay.var('p1', shape=dshape, dtype=dtype)
     p2 = relay.var('p2', shape=dshape, dtype=dtype)
     fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2))
     with sb.if_scope(fused_gt(x, y)):
         sb.ret(relay.Function([], x))
     with sb.else_scope():
         sb.ret(relay.Function([], y))
     return relay.Function([x, y], relay.Call(sb.get(), []))
Example 24
def test_compose():
    mod = relay.Module()
    p = Prelude(mod)

    compose = p.compose

    # stub out every function except compose so VM compilation does not need pattern matching
    # TODO(wweic): remove the hack and implement pattern match
    for v, _ in mod.functions.items():
        if v.name_hint == 'compose':
            continue
        mod[v] = relay.const(0)

    # add_one = fun x -> x + 1
    sb = relay.ScopeBuilder()
    x = relay.var('x', 'float32')
    x1 = sb.let('x1', x)
    xplusone = x1 + relay.const(1.0, 'float32')
    sb.ret(xplusone)
    body = sb.get()
    add_one = relay.GlobalVar("add_one")
    add_one_func = relay.Function([x], body)

    # add_two = compose(add_one, add_one)
    sb = relay.ScopeBuilder()
    y = relay.var('y', 'float32')
    add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
    add_two_res = add_two_func(y)
    sb.ret(add_two_res)
    add_two_body = sb.get()

    mod[add_one] = add_one_func

    f = relay.Function([y], add_two_body)
    mod["main"] = f

    x_data = np.array(np.random.rand()).astype('float32')
    result = veval(mod)(x_data)

    tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
Example 25
def test_adt_compose():
    mod = relay.Module()
    p = Prelude(mod)

    compose = p.compose

    # add_one = fun x -> x + 1
    sb = relay.ScopeBuilder()
    x = relay.var('x', 'float32')
    x1 = sb.let('x1', x)
    xplusone = x1 + relay.const(1.0, 'float32')
    sb.ret(xplusone)
    body = sb.get()
    add_one = relay.GlobalVar("add_one")
    add_one_func = relay.Function([x], body)

    # add_two = compose(add_one, add_one)
    sb = relay.ScopeBuilder()
    y = relay.var('y', 'float32')
    add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
    add_two_res = add_two_func(y)
    sb.ret(add_two_res)
    add_two_body = sb.get()

    mod[add_one] = add_one_func

    f = relay.Function([y], add_two_body)
    mod["main"] = f

    vm = create_vm(mod)
    ser = serializer.Serializer(vm)
    code, lib = ser.serialize()
    deser = deserializer.Deserializer(code, lib)
    des_vm = deser.deserialize()

    x_data = np.array(np.random.rand()).astype('float32')
    result = veval(des_vm, x_data)

    tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
Example 26
 def expected(dshape, dtype):
     x = relay.var("x", shape=dshape, dtype=dtype)
     y = relay.var("y", shape=dshape, dtype=dtype)
     sb = relay.ScopeBuilder()
     p1 = relay.var("p1", shape=dshape, dtype=dtype)
     p2 = relay.var("p2", shape=dshape, dtype=dtype)
     fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2))
     fused_gt = fused_gt.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     with sb.if_scope(fused_gt(x, y)):
         sb.ret(relay.Function([], x))
     with sb.else_scope():
         sb.ret(relay.Function([], y))
     return relay.Function([x, y], relay.Call(sb.get(), []))
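
This expected() pairs with the before(dshape, dtype) builder shown in Example 13. A hedged sketch of how such a pair is typically exercised, with hypothetical dshape and dtype values and an assumed fusion opt level:

dshape, dtype = (2, 2), "float32"             # hypothetical values
mod = tvm.IRModule.from_expr(before(dshape, dtype))
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
# After fusion, the greater(x, y) condition is wrapped in a "Primitive"
# function, matching the shape of expected(dshape, dtype) above.
print(mod["main"])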
Example 27
def test_adt_compose():
    mod = relay.Module()
    p = Prelude(mod)

    compose = p.compose

    # add_one = fun x -> x + 1
    sb = relay.ScopeBuilder()
    x = relay.var('x', 'float32')
    x1 = sb.let('x1', x)
    xplusone = x1 + relay.const(1.0, 'float32')
    sb.ret(xplusone)
    body = sb.get()
    add_one = relay.GlobalVar("add_one")
    add_one_func = relay.Function([x], body)

    # add_two = compose(add_one, add_one)
    sb = relay.ScopeBuilder()
    y = relay.var('y', 'float32')
    add_two_func = sb.let('add_two', compose(add_one_func, add_one_func))
    add_two_res = add_two_func(y)
    sb.ret(add_two_res)
    add_two_body = sb.get()

    mod[add_one] = add_one_func

    f = relay.Function([y], add_two_body)
    mod["main"] = f

    exe = create_exec(mod)
    code, lib = exe.save()
    des_exec = _vm.Executable.load_exec(code, lib)
    des_vm = _vm.VirtualMachine(des_exec)
    des_vm.init(tvm.cpu())

    x_data = np.array(np.random.rand()).astype('float32')
    result = veval(des_vm, x_data)

    tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
Example 28
def fuse_partitions(pre_mod, mid_mod, post_mod):
    """Combine prefix, middle, and suffix modules into a single module.

    The combined module includes an additional `main` that fuses all three
    partitions together.

    Parameters
    ----------
    pre_mod : tvm.IRModule
        Module containing an input quantization function

    mid_mod : tvm.IRModule
        Module containing core of a quantized inference function

    post_mod : tvm.IRModule
        Module containing an output dequantization function

    Returns
    -------
    fused_mod : tvm.IRModule
        Module containing the input quantization, core quantized inference,
        output dequantization, and full quantized inference functions
    """
    pre_func = pre_mod['main']
    mid_func = mid_mod['main']
    post_func = post_mod['main']
    # create a module containing the prefix, middle, and suffix partitions
    fused_mod = tvm.IRModule(functions={
        relay.GlobalVar('quantize_inputs'): pre_func,
        relay.GlobalVar('quantized_main'): mid_func,
        relay.GlobalVar('dequantize_outputs'): post_func,
    })
    # construct a `main` that strings together the partitions, such that its
    # behaviour is equivalent to `main` in an *unpartitioned* module
    scope_builder = relay.ScopeBuilder()
    fused_mod_main_params = [relay.Var(param.name_hint) for param in pre_func.params]
    quantized_inputs = scope_builder.let('quantized_inputs', relay.Call(
        fused_mod.get_global_var('quantize_inputs'),
        fused_mod_main_params
    ))
    quantized_outputs = scope_builder.let('quantized_outputs', relay.Call(
        fused_mod.get_global_var('quantized_main'),
        [relay.TupleGetItem(quantized_inputs, i) for i in range(len(pre_func.ret_type.fields))]
    ))
    dequantized_outputs = scope_builder.let('dequantized_outputs', relay.Call(
        fused_mod.get_global_var('dequantize_outputs'),
        [quantized_outputs]
    ))
    scope_builder.ret(dequantized_outputs)
    fused_mod['main'] = relay.Function(fused_mod_main_params, scope_builder.get())
    return fused_mod
Example 29
def test_let_scalar():
    sb = relay.ScopeBuilder()

    x = relay.var('x', 'float32')
    x1 = sb.let('x1', x)
    xplusone = x1 + relay.const(42.0, 'float32')
    sb.ret(xplusone)
    body = sb.get()

    f = relay.Function([x], body)

    x_data = np.array(np.random.rand()).astype('float32')
    result = veval(f, x_data)
    tvm.testing.assert_allclose(result.asnumpy(), x_data + 42.0)
Example 30
def lstm_definition(batch_size, input_size, hidden_size, time_steps,
                    time_axis=1):
    state_tensor_type = relay.TensorType((batch_size, hidden_size))
    state_tuple_type = relay.TupleType([state_tensor_type, state_tensor_type])

    input_var = relay.var("input", shape=(batch_size, time_steps, input_size))
    state_var = relay.var("state", type_annotation=state_tuple_type)
    i2h_weight_var = relay.var("i2h_weight", shape=(4*hidden_size, input_size))
    h2h_weight_var = relay.var("h2h_weight", shape=(4*hidden_size, hidden_size))
    i2h_bias_var = relay.var("i2h_bias", shape=(4*hidden_size,))
    h2h_bias_var = relay.var("h2h_bias", shape=(4*hidden_size,))

    # in this case, we are ignoring the state outputs
    builder = relay.ScopeBuilder()
    cell_var = builder.let("lstm_cell", relay_lstm_cell(batch_size, input_size, hidden_size))
    splits = builder.let("splits", relay.split(input_var, time_steps, time_axis).astuple())
    last_state = state_var
    seq_outs = []
    for i in range(time_steps):
        squeezed = builder.let(f"squeezed_{i}", relay.squeeze(relay.TupleGetItem(splits, i), axis=[time_axis]))
        cell_out = builder.let(f"cell_out_{i}",
                               cell_var(squeezed, last_state,
                                        i2h_weight_var, h2h_weight_var,
                                        i2h_bias_var, h2h_bias_var))
        new_seq_out = builder.let(f"seq_out_{i}", relay.TupleGetItem(cell_out, 0))
        seq_outs.append(new_seq_out)
        new_hidden = builder.let(f"state_update_{i}", relay.TupleGetItem(cell_out, 1))
        last_state = new_hidden

    stacked = builder.let("stacked", relay.stack(seq_outs, axis=time_axis))
    # finally reshape to match pytorch's semantics (one layer)
    reshape_hidden = builder.let("final_hidden",
                                 relay.reshape(relay.TupleGetItem(last_state, 0),
                                               (1, batch_size, hidden_size)))
    reshape_cell = builder.let("final_cell",
                               relay.reshape(relay.TupleGetItem(last_state, 1),
                                             (1, batch_size, hidden_size)))
    builder.ret(relay.Tuple([stacked, reshape_hidden, reshape_cell]))

    ret_type = relay.TupleType([
        relay.TensorType((batch_size, time_steps, hidden_size)),
        relay.TensorType((1, batch_size, hidden_size)),
        relay.TensorType((1, batch_size, hidden_size))
    ])

    return relay.Function([input_var, state_var, i2h_weight_var, h2h_weight_var,
                           i2h_bias_var, h2h_bias_var],
                          builder.get(),
                          ret_type=ret_type)