def __init__(self, ch, env):
    """Prepare the Chainer link ``ch`` so that its ``forward`` can be compiled.

    Parses the source of ``ch.forward`` into a gast AST and wraps the method
    as a callable. It then temporarily replaces the link's own (non-nested)
    parameters and its ``children`` method with compile-time stand-ins.

    Each replacement registers a zero-argument closure in
    ``env.restore_funcs`` that puts the original object back. This lets the
    link be run by Chainer again after compilation.

    Args:
        ch: The Chainer link whose ``forward`` is being compiled.
        env: Compilation environment. Only ``env.restore_funcs`` (a list that
            receives restore callbacks) is touched here.
    """
    src = clip_head(inspect.getsource(ch.forward))
    dprint(src)
    self.ast = gast.ast_to_gast(ast.parse(src)).body[0]
    self.call = User_Defined_Func_In_Link(ch, ch.forward).call

    # The rest handles the very first call, which comes from outside.
    # Number of forward() arguments, excluding `self`.
    self.forward_arglen = len(self.ast.args.args) - 1

    # The parameters initialised by the link must be overwritten with
    # tensor value-info stand-ins. They have to be put back later so the
    # link can still be executed by Chainer, so a restore callback for each
    # one is added to env.restore_funcs.
    self.inits = []
    for name, param in ch.namedparams():
        name = name[1:]  # drop the leading '/'
        if '/' in name:
            # Still contains a '/': a nested name (presumably a parameter of
            # a child link), which is not replaced here.
            continue
        tensor = helper.make_tensor_value_info(
            '/' + name, TensorProto.FLOAT, list(param.shape))
        self.inits.append(tensor)
        original = getattr(ch, name)
        setattr(ch, name, tensor)
        # Bind `name` and `original` as defaults. A plain closure would
        # late-bind to the loop variables (and to the later rebinding used
        # for `children`), so every callback would restore the wrong object.
        env.restore_funcs.append(
            lambda name=name, original=original: setattr(ch, name, original))

    # TODO(satos) This can be removed once Yield can be compiled.
    children = ch.children
    ch.children = Func(lambda _, __, ___: children())
    env.restore_funcs.append(lambda: setattr(ch, 'children', children))
def eval_attribute(nast, env):
    """Evaluate an attribute access ``<value>.<attr>`` at compile time.

    Python-level values delegate to their own ``get_attribute``. For
    non-Python values, only ``shape``, ``size`` and ``append`` are supported.

    Args:
        nast: The gast ``Attribute`` node being evaluated.
        env: The compilation environment.

    Returns:
        One of the following:
        - for ``shape``: the result of a ``Shape`` node fed into
          ``ChainerSequenceSeparate``;
        - for ``size``: a ``Size`` node;
        - for ``append``: a ``Func`` that appends to the named variable.

    Raises:
        Exception: for any other attribute on a non-Python value.
    """
    target = eval_ast(nast.value, env)
    if target.is_py:
        return target.get_attribute(nast.attr, env)

    attr = nast.attr
    if attr == 'shape':
        shape_node = env.calc(
            'Shape',
            inputs=[target.to_tensor(env).name],
            npdtype=np.int64,
        )
        return env.calc_seq(
            'ChainerSequenceSeparate',
            inputs=[shape_node.name],
        )

    if attr == 'size':
        return env.calc(
            'Size',
            inputs=[target.to_tensor(env).name],
            npdtype=np.int64,
        )

    if attr == 'append':
        # TODO(satos) Stop fudging this.
        assert isinstance(nast.value, gast.Name) and \
            nast.value.id in env.get_var_dict().keys()
        var_name = nast.value.id
        # Inaccurate when var_name is a reference (alias) to another list.
        # For example, after
        #   x = y
        #   x.append(3)
        # only x is rebound, so y is not updated.

        def append(args, _, env):
            assert len(args) == 1
            item = args[0].to_tensor(env)
            appended = env.calc_seq(
                'ChainerSequenceAppend',
                inputs=[target.to_sequence(env).name, item.name],
            )
            env.set_var(var_name, _value(appended))
            return None

        return Func(append)

    raise Exception('Unimplemented attribute ', nast.attr, ' for tensor')
def eval_call(nast, env):
    """Evaluate a call expression ``f(*args, **kwargs)`` at compile time.

    The callee is evaluated first and must be a Python-level value.

    - Calls to the listed ``logging`` functions are dropped.
    - Callees registered in ``Func2NodeClass`` are converted directly.
    - Any other callee is wrapped first and then invoked through its
      ``call`` method. Wrapped kinds are user functions, bound methods,
      builtins from ``builtin_functions``, user classes and Chainer links.

    Args:
        nast: The gast ``Call`` node.
        env: The compilation environment.

    Returns:
        Whatever the selected converter's ``call`` returns, or ``None`` for
        dropped logging calls.

    Raises:
        TypeError: if the callee is not a Python-level value.
    """
    callee = eval_ast(nast.func, env)
    if not callee.is_py:
        raise TypeError('Expected a callable: %s' % callee.value)
    fn = callee.value

    # TODO(hamaji): Merge this logic with is_print_logging. Also,
    # maybe it's better to try emitting ChainerPrint.
    logging_funcs = (logging.debug, logging.info, logging.warn,
                     logging.warning, logging.error)
    if fn in logging_funcs:
        return None

    args = []
    for arg_node in nast.args:
        if isinstance(arg_node, gast.Starred):
            args.extend(eval_ast(arg_node.value, env))
        else:
            args.append(eval_ast(arg_node, env))
    keywords = {kw.arg: eval_ast(kw.value, env) for kw in nast.keywords}

    # Functions such as those in chainer.functions are converted here.
    if fn in Func2NodeClass:
        return Func2NodeClass[fn].call(args, keywords, env)

    dprint(fn, fn.__class__)
    if isinstance(fn, types.FunctionType):
        fn = User_Defined_Function(fn)
    elif isinstance(fn, types.MethodType):
        method_impl = fn.__func__
        receiver = fn.__self__
        if method_impl == chainer.FunctionNode.apply:
            # `apply` is compiled as the node's `forward`.
            fn = User_Defined_Func_In_Link(receiver, receiver.forward)
        elif method_impl == chainer.FunctionNode.retain_inputs:
            # TODO(satos) This probably needs to tell the backward side
            # something.
            fn = Func(lambda _, __, ___: None)
        else:
            fn = User_Defined_Func_In_Link(receiver, fn)
    elif fn in builtin_functions:
        fn = builtin_functions[fn]
    elif isinstance(fn, type):
        # This must be creating an instance of some (non-builtin) class.
        assert fn.__module__ != 'builtins'
        fn = User_Defined_Class(fn).init_wrapper
    elif isinstance(fn, chainer.link.Link):
        fn = convert_link(fn, env)

    dprint('converted to', fn)
    return fn.call(args, keywords, env)
def __init__(self, classtype):
    """Build ``self.init_wrapper``, which emulates instantiating ``classtype``.

    Compilation needs an object that has ``classtype``'s methods but has not
    had its ``__init__`` run. To get one, a throwaway subclass is defined
    whose ``__init__`` deliberately does NOT call the parent's, and it is
    instantiated once here.

    Each call to the wrapper does the following:
    - runs ``classtype.__init__`` on that object through
      ``User_Defined_Func_In_Link``, unless that ``__init__`` is a
      C-level slot wrapper;
    - returns the object.

    Args:
        classtype: The user-defined class being instantiated.
    """
    class Tmp(classtype):
        def __init__(_):
            pass

    instance = Tmp()
    instance.__module__ = classtype.__module__

    # Type of C-implemented __init__s such as object.__init__
    # (a "slot wrapper").
    slot_wrapper_type = type(str.__init__)

    def construct(args, kwargs, env):
        user_init = classtype.__init__
        if not isinstance(user_init, slot_wrapper_type):
            User_Defined_Func_In_Link(instance, user_init).call(
                args, kwargs, env)
        return instance

    self.init_wrapper = Func(construct)
"ChainerGenericLen", inputs=[x.name], )


class Builtin_List(Callable):
    """Conversion rule for calls to the builtin ``list``."""

    def __init__(self):
        # The identity lambda presumably only supplies the call signature
        # (one positional ``x``) to ``Callable``. TODO confirm.
        super(Builtin_List, self).__init__(lambda x: x)

    def call_impl(self, env, x):
        # ``list(x)``: convert x to a sequence and emit an Identity node on it.
        return env.calc(
            "Identity",
            inputs=[x.to_sequence(env).name],
        )


def builtin_range(args, _, env):
    """Conversion rule for calls to the builtin ``range``.

    If every argument is a Python-level value, a real Python ``range`` over
    the argument values is returned (a constant loop). Otherwise, every
    argument is converted to a tensor and a ``ChainerSequenceRange`` node is
    emitted.

    The second parameter (presumably the keyword arguments; verify against
    ``Func``) is unused.
    """
    if all(a.is_py for a in args):
        # print('constant loop',args)
        return range(*(a.value for a in args))
    return env.calc_seq('ChainerSequenceRange',
                        inputs=[a.to_tensor(env).name for a in args])


# Maps Python builtins to the converter objects used when they are called.
# eval_call looks callees up here.
builtin_functions = {
    len: Builtin_Len(),
    list: Builtin_List(),
    range: Func(builtin_range),
}