示例#1
0
    def __init__(self, ch, env):
        """Prepare the link ``ch`` for conversion of its ``forward``.

        Parses the source of ``ch.forward`` and records how many arguments it
        takes besides ``self``. It then temporarily replaces each direct
        parameter of ``ch`` with a tensor value-info object built by
        ``helper.make_tensor_value_info`` (FLOAT dtype, the parameter's
        shape). It also replaces ``ch.children`` with a ``Func`` wrapper.
        Every overwrite registers an undo closure in ``env.restore_funcs``,
        so the original attributes can be put back and the link executed by
        Chainer again.

        Args:
            ch: The link whose ``forward`` is being converted.
            env: The conversion environment; only ``env.restore_funcs`` is
                modified here.
        """
        src = clip_head(inspect.getsource(ch.forward))
        dprint(src)
        self.ast = gast.ast_to_gast(ast.parse(src)).body[0]

        self.call = User_Defined_Func_In_Link(ch, ch.forward).call

        # Used for the outermost (externally invoked) forward call;
        # `- 1` drops `self` from the parameter list.
        self.forward_arglen = len(self.ast.args.args) - 1

        # The initialized parameters have to be overwritten here. Chainer
        # must be able to run the link again afterwards, so each overwrite
        # registers its undo in env.restore_funcs.
        self.inits = []

        for name, param in ch.namedparams():
            name = name[1:]  # drop the leading '/'
            if '/' in name:
                # Only direct parameters of `ch` are replaced; nested
                # (child-link) parameter paths are skipped here.
                continue
            info = helper.make_tensor_value_info('/' + name, TensorProto.FLOAT,
                                                 list(param.shape))
            self.inits.append(info)
            original = getattr(ch, name)
            setattr(ch, name, info)
            # Bind name/original as defaults. A plain closure would read the
            # loop's final values at call time, so every restore would reset
            # the same attribute.
            env.restore_funcs.append(
                lambda name=name, original=original: setattr(ch, name, original))

        # TODO(satos): this can be removed once Yield can be compiled.
        original_children = getattr(ch, 'children')
        setattr(ch, 'children',
                Func(lambda _, __, ___: original_children()))
        env.restore_funcs.append(
            lambda: setattr(ch, 'children', original_children))
示例#2
0
def eval_attribute(nast, env):
    """Evaluate an attribute access ``<value>.<attr>``.

    Python-side values (``body.is_py``) delegate to ``body.get_attribute``.
    For non-Python (graph) values only ``shape``, ``size`` and ``append`` are
    supported; they are lowered to graph operations.

    Args:
        nast: The attribute node (presumably a ``gast.Attribute``); its
            ``value`` and ``attr`` fields are read.
        env: The conversion environment.

    Returns:
        The evaluated attribute. For ``append`` this is a ``Func`` that
        appends its single argument to the named list variable.

    Raises:
        Exception: If ``attr`` is not supported on a non-Python value.
    """
    body = eval_ast(nast.value, env)

    if body.is_py:
        return body.get_attribute(nast.attr, env)

    attr = nast.attr

    if attr == 'shape':
        # Shape gives a 1-D int64 tensor. ChainerSequenceSeparate then
        # turns it into a sequence (presumably one entry per dimension).
        shape = env.calc(
            'Shape',
            inputs=[body.to_tensor(env).name],
            npdtype=np.int64,
        )
        return env.calc_seq(
            'ChainerSequenceSeparate',
            inputs=[shape.name],
        )

    if attr == 'size':
        return env.calc(
            'Size',
            inputs=[body.to_tensor(env).name],
            npdtype=np.int64,
        )

    if attr == 'append':
        # TODO(satos): stop cheating here.
        assert isinstance(
            nast.value,
            gast.Name) and nast.value.id in env.get_var_dict()
        var_name = nast.value.id

        # NOTE: this is inaccurate when `var_name` is an alias. For example,
        #   x = y
        #   x.append(3)
        # rebinds only `x`; `y` is not updated.

        def append(args, _, env):
            assert len(args) == 1
            v = args[0].to_tensor(env)
            env.set_var(
                var_name,
                _value(
                    env.calc_seq(
                        'ChainerSequenceAppend',
                        inputs=[body.to_sequence(env).name, v.name],
                    )))
            return None

        return Func(append)

    raise Exception('Unimplemented attribute %s for tensor' % attr)
示例#3
0
def eval_call(nast, env):
    """Evaluate a call expression and return the converted result.

    The callee is resolved in this order:
      1. Logging calls (debug/info/warn/warning/error) are dropped, and
         None is returned.
      2. Callees registered in ``Func2NodeClass`` are converted by their
         node class.
      3. Otherwise the callee is wrapped according to its kind: plain
         function, bound method, supported builtin, user class, or Chainer
         link. Its ``call`` is then invoked. Any other callee is assumed to
         already expose ``call``.

    Args:
        nast: The call node; its ``func``, ``args`` and ``keywords`` are read.
        env: The conversion environment.

    Raises:
        TypeError: If the callee is not a Python-side value.
    """
    fn = eval_ast(nast.func, env)
    if not fn.is_py:
        # Wrap in a 1-tuple so that a tuple-valued `fn.value` cannot break
        # the %-formatting itself.
        raise TypeError('Expected a callable: %s' % (fn.value,))
    fn = fn.value

    # TODO(hamaji): Merge this logic with is_print_logging. Also,
    # maybe it's better to try emitting ChainerPrint.
    # `logging.warn` is a deprecated alias of `logging.warning`; look it up
    # defensively so this keeps working if it is ever removed.
    ignored_logging_funcs = (logging.debug, logging.info,
                             getattr(logging, 'warn', logging.warning),
                             logging.warning, logging.error)
    if fn in ignored_logging_funcs:
        return None

    args = []
    for ag in nast.args:
        if isinstance(ag, gast.Starred):
            args.extend(eval_ast(ag.value, env))
        else:
            args.append(eval_ast(ag, env))

    keywords = {kw.arg: eval_ast(kw.value, env) for kw in nast.keywords}

    # Callees such as chainer.functions are intercepted here and converted
    # by their registered node class.
    if fn in Func2NodeClass:
        return Func2NodeClass[fn].call(args, keywords, env)

    dprint(fn, fn.__class__)
    if isinstance(fn, types.FunctionType):
        fn = User_Defined_Function(fn)
    elif isinstance(fn, types.MethodType):
        if fn.__func__ == chainer.FunctionNode.apply:
            # `apply` is redirected to the node's `forward`.
            fn = User_Defined_Func_In_Link(
                fn.__self__, fn.__self__.forward)
        elif fn.__func__ == chainer.FunctionNode.retain_inputs:
            # TODO(satos): this probably needs to tell the backward side
            # something.
            fn = Func(lambda _, __, ___: None)
        else:
            fn = User_Defined_Func_In_Link(fn.__self__, fn)
    elif fn in builtin_functions:
        fn = builtin_functions[fn]
    elif isinstance(fn, type):
        # Calling a class: this should be constructing some instance.
        assert fn.__module__ != 'builtins'
        fn = User_Defined_Class(fn).init_wrapper
    elif isinstance(fn, chainer.link.Link):
        fn = convert_link(fn, env)

    dprint('converted to', fn)
    return fn.call(args, keywords, env)
示例#4
0
    def __init__(self, classtype):
        """Build ``self.init_wrapper``, a ``Func`` that emulates ``classtype(...)``.

        A single placeholder instance of a subclass of ``classtype`` is created
        up front. When the wrapper is called, the class's own ``__init__`` (if
        it is not a C-level slot wrapper) is converted and run against that
        instance, and the instance is returned.
        """
        # We need an object that has classtype's methods but whose __init__
        # has NOT been run, so subclass it with an __init__ that deliberately
        # skips the parent initializer.
        class Tmp(classtype):
            def __init__(_):
                pass

        instance = Tmp()
        instance.__module__ = classtype.__module__

        # type(str.__init__) is the "slot wrapper" type. An __init__ of that
        # type has no Python body to convert, so it is skipped.
        slot_wrapper_type = type(str.__init__)

        def construct(args, kwargs, env):
            initializer = classtype.__init__
            if isinstance(initializer, slot_wrapper_type):
                return instance
            User_Defined_Func_In_Link(
                instance, initializer).call(args, kwargs, env)
            return instance

        self.init_wrapper = Func(construct)
示例#5
0
            "ChainerGenericLen",
            inputs=[x.name],
        )


class Builtin_List(Callable):
    """Conversion of the builtin ``list`` applied to a single argument.

    The argument is converted to its sequence form and passed through an
    ``Identity`` node.
    """

    def __init__(self):
        # The single-parameter identity lambda is handed to Callable;
        # presumably it defines the accepted call signature (`x`) — TODO
        # confirm against Callable.
        super().__init__(lambda x: x)

    def call_impl(self, env, x):
        sequence_name = x.to_sequence(env).name
        return env.calc("Identity", inputs=[sequence_name])


def builtin_range(args, _, env):
    """Convert a call to the builtin ``range``.

    Args:
        args: Evaluated positional arguments; each exposes ``is_py``,
            ``value`` and ``to_tensor(env)``.
        _: Keyword arguments (ignored).
        env: The conversion environment.

    Returns:
        A plain Python ``range`` when every argument is a Python-side value
        (a constant loop). Otherwise the result of a
        ``ChainerSequenceRange`` node over the arguments' tensor names.
    """
    if any(not a.is_py for a in args):
        bound_names = [a.to_tensor(env).name for a in args]
        return env.calc_seq('ChainerSequenceRange', inputs=bound_names)

    # All bounds are known at conversion time: build the range eagerly.
    return range(*[a.value for a in args])


# Python builtins that have a dedicated conversion, keyed by the builtin
# object itself. Callers look a callee up here to find its converter.
builtin_functions = {}
builtin_functions[len] = Builtin_Len()
builtin_functions[list] = Builtin_List()
builtin_functions[range] = Func(builtin_range)