示例#1
0
    def __init__(self, ch, env):
        """Prepare link `ch` so that its ``forward`` can be traced.

        Parses the source of ``ch.forward`` into a gast node (``self.ast``),
        builds a callable for it (``self.call``), records the number of
        forward arguments excluding ``self`` (``self.forward_arglen``), and
        temporarily replaces each top-level parameter of ``ch`` with an ONNX
        ValueInfoProto placeholder. The placeholders are collected in
        ``self.inits``. ``ch.children`` is also replaced by a ``Func`` wrapper.

        For every attribute overwritten on ``ch``, a zero-argument closure
        that puts the original value back is appended to
        ``env.restore_funcs``.

        Args:
            ch: the link whose ``forward`` is being compiled.
            env: compilation environment; only ``env.restore_funcs`` is
                touched here.
        """
        src = clip_head(inspect.getsource(ch.forward))
        dprint(src)
        self.ast = gast.ast_to_gast(ast.parse(src)).body[0]

        self.call = User_Defined_Func_In_Link(ch, ch.forward).call

        # Used for the first (outermost) call from outside: the number of
        # forward arguments, not counting `self`.
        self.forward_arglen = len(self.ast.args.args) - 1

        # The initialized parameters have to be overwritten with placeholders
        # here. They must be restored afterwards so the link can run under
        # chainer again, so a restore callback is added to env.restore_funcs.
        self.inits = []

        for name, param in ch.namedparams():
            name = name[1:]  # drop the leading '/'
            if '/' in name:
                # Remaining '/' means a nested path (presumably a child
                # link's parameter); only direct parameters are replaced.
                continue
            t = helper.make_tensor_value_info('/' + name, TensorProto.FLOAT,
                                              list(param.shape))
            self.inits.append(t)
            original = getattr(ch, name)
            setattr(ch, name, t)
            # Bind name/original as defaults. A plain closure would late-bind
            # them and restore only the last name, to whatever the variable
            # holds when the callback finally runs.
            env.restore_funcs.append(
                lambda name=name, original=original:
                    setattr(ch, name, original))

        # TODO(satos) This can be removed once Yield can be compiled.
        original_children = getattr(ch, 'children')
        setattr(ch, 'children', Func(lambda _, __, ___: original_children()))
        env.restore_funcs.append(
            lambda: setattr(ch, 'children', original_children))
示例#2
0
def eval_ast(nast, env):
    """Evaluate a gast node (or a list of nodes) in `env` and return a Value.

    Before evaluating, this asserts that no variable in `env` is bound
    directly to an ``onnx.ValueInfoProto``. For a single node, it logs the
    node indented by the current recursion depth. The actual work is done by
    ``eval_ast_impl``, and the result is normalized with ``_value``.
    """
    global _eval_ast_depth

    for k, v in env.get_var_dict().items():
        assert not isinstance(v, onnx.ValueInfoProto), '%s %s' % (k, v)

    if not isinstance(nast, list):
        dprint('-' * _eval_ast_depth, gast.dump(nast), env.get_var_dict().keys())

    _eval_ast_depth += 1
    try:
        r = eval_ast_impl(nast, env)
    finally:
        # Restore the depth even if evaluation raises. Otherwise every later
        # log line would be indented by a stale depth.
        _eval_ast_depth -= 1
    return _value(r)
示例#3
0
def eval_call(nast, env):
    """Evaluate a gast Call node and return the result of the callee's ``call``.

    The callee expression is evaluated first and must be a Python-level
    value (``is_py``); otherwise TypeError is raised. Calls to the
    ``logging`` functions debug/info/warn/warning/error are dropped and
    evaluate to None.

    Positional arguments are evaluated left to right, with ``*starred``
    arguments expanded, and then keyword arguments are evaluated. Callees
    registered in Func2NodeClass are lowered through that mapping. Any other
    callee is first wrapped into an object that exposes
    ``call(args, keywords, env)``.
    """
    callee = eval_ast(nast.func, env)
    if not callee.is_py:
        raise TypeError('Expected a callable: %s' % callee.value)
    target = callee.value

    # TODO(hamaji): Merge this logic with is_print_logging. Also,
    # maybe it's better to try emitting ChainerPrint.
    ignored_loggers = (logging.debug, logging.info,
                       logging.warn, logging.warning, logging.error)
    if target in ignored_loggers:
        return None

    positional = []
    for arg_node in nast.args:
        if not isinstance(arg_node, gast.Starred):
            positional.append(eval_ast(arg_node, env))
        else:
            # f(*xs): splice the evaluated iterable into the argument list.
            positional.extend(eval_ast(arg_node.value, env))

    named = {kw.arg: eval_ast(kw.value, env) for kw in nast.keywords}

    # Functions such as those from chainer.functions are intercepted here
    # and converted directly via their registered node class.
    if target in Func2NodeClass:
        return Func2NodeClass[target].call(positional, named, env)

    dprint(target, target.__class__)
    if isinstance(target, types.FunctionType):
        target = User_Defined_Function(target)
    elif isinstance(target, types.MethodType):
        owner = target.__self__
        underlying = target.__func__
        if underlying == chainer.FunctionNode.apply:
            # FunctionNode.apply is traced as the node's forward method.
            target = User_Defined_Func_In_Link(owner, owner.forward)
        elif underlying == chainer.FunctionNode.retain_inputs:
            # TODO(satos) The backward side probably needs to be told
            # something about this.
            target = Func(lambda _, __, ___: None)
        else:
            target = User_Defined_Func_In_Link(owner, target)
    elif target in builtin_functions:
        target = builtin_functions[target]
    elif isinstance(target, type):
        # Calling a class: this should be constructing an instance of it.
        assert target.__module__ != 'builtins'
        target = User_Defined_Class(target).init_wrapper
    elif isinstance(target, chainer.link.Link):
        target = convert_link(target, env)

    dprint('converted to', target)
    return target.call(positional, named, env)
示例#4
0
 def __init__(self, func):
     """Remember `func` and parse its source into a gast FunctionDef.

     Sets ``self.func`` to `func` and ``self.ast`` to the first statement of
     the parsed source. When assertions are enabled, raises AssertionError if
     that statement is not a FunctionDef.
     """
     self.func = func
     source = clip_head(inspect.getsource(func))
     dprint(source)
     module_node = gast.ast_to_gast(ast.parse(source))
     self.ast = module_node.body[0]
     assert isinstance(self.ast, gast.gast.FunctionDef)
示例#5
0
def compile_model(model, inputs):
    """Trace `model` on placeholder inputs and build an ONNX model from it.

    Args:
        model: the chainer link to compile. The module it was defined in
            (``sys.modules[model.__module__]``) seeds the tracing Env.
        inputs: one entry per model input:
            - list/tuple: becomes a sequence placeholder (``new_sequence``);
            - None: becomes a tensor placeholder with no type info;
            - int: converted with ``np.array`` and its shape/dtype used;
            - anything else: passed through ``chainer.cuda.to_cpu`` and its
              shape/dtype used.

    Returns:
        The ModelProto produced by ``helper.make_model`` for the traced graph.
    """
    init_id2name(model)
    env = Env(sys.modules[model.__module__])
    try:
        molk = User_Defined_Link(model, env)

        input_tensors = []
        for inp in inputs:
            # TODO(hamaji): Set valid type info.
            if isinstance(inp, (list, tuple)):
                x = new_sequence()
            elif inp is None:
                x = new_tensor()
            else:
                if isinstance(inp, int):
                    inp = np.array(inp)
                else:
                    # TODO(durswd): This code requires chainer6.x
                    inp = chainer.cuda.to_cpu(inp)
                x = new_tensor(dims=inp.shape, dtype=inp.dtype)
            input_tensors.append(x)

        input_values = [Value(t) for t in input_tensors]
        v = molk.call(input_values, [], env)

        dprint('output_tensors', v)
        if isinstance(v.value, tuple):
            output_tensors = list(v.value)  # tuple result: unpack it
        else:
            output_tensors = [v]  # otherwise treat it as a single tensor

        # Tensors collected in env.init_tensors (presumably the model's
        # weights) are appended as extra graph inputs.
        input_tensors += list(env.init_tensors.values())
    finally:
        # Always run the restore callbacks, even if tracing raised, so that
        # attributes temporarily replaced on `model` are put back and the
        # model stays usable under chainer.
        for f in env.restore_funcs:
            f()

    outputs_vi = [o.to_value_info(env) for o in output_tensors]
    graph = make_graph(env.nodes, 'name_is_unknown_now', input_tensors,
                       outputs_vi)

    # TODO: attach initializers to the inputs that are weights, while keeping
    # variable dimensions (batch_size, input_size, ...) as free as possible.

    # Chainer-compiler-specific custom nodes make the ONNX checker fail, so
    # the check is disabled:
    # checker.check_graph(graph)
    mo = helper.make_model(graph)

    return mo