def match_types(tys1, tys2):
    """Build one substitution by matching ``tys1`` against ``tys2`` pairwise.

    Pairs are taken left to right via ``zip``, so surplus elements of the
    longer sequence are ignored. Before a pair is matched, the substitution
    accumulated so far is applied to both of its types. Bindings found for
    earlier pairs therefore constrain later ones. Each pair's result from
    ``match_type`` is merged in with ``utils.add_dict``.

    Returns the accumulated substitution dict.
    """
    result = {}
    for lhs, rhs in zip(tys1, tys2):
        lhs_now, rhs_now = apply_subst(result, lhs), apply_subst(result, rhs)
        utils.add_dict(result, match_type(lhs_now, rhs_now))
    return result
def infer_ListComp(self, node):
    """Infer the type of a list comprehension ``[elt for target in iter]``.

    Only one generator with no ``if`` filters is supported. Inference runs
    in a copy of the engine. Afterwards, only ``nodetype`` and
    ``subroutine_node`` are merged back into ``self``.

    Returns:
        ``TyList`` of the type inferred for ``node.elt``. The same type is
        also recorded in ``self.nodetype[node]``.
    """
    # ListComp(expr elt, comprehension* generators)
    # Only a single generator is supported; nested ``for`` clauses
    # (len > 1) are rejected.
    assert len(node.generators) == 1
    gen = node.generators[0]
    # TODO: handle cases where len(gen.ifs) > 0
    assert len(gen.ifs) == 0

    # Work in a copied engine. Presumably this keeps the target's binding
    # local to the comprehension; only nodetype/subroutine_node are copied
    # back below.
    tc = copy_InferenceEngine(self)
    ty_iteration = tc.infer_expr(gen.iter)
    ty_i = tc.generate_fresh_TyVar(gen.target)

    if isinstance(ty_iteration, TyTensor):
        # Iterating a tensor yields elements with the leading dimension
        # removed. The shape may be unknown (None). Slice it only when
        # present, because ``None[1:]`` would raise TypeError.
        if ty_iteration.shape is not None:
            ty_i_ = TyTensor(ty_iteration.dtype, ty_iteration.kind,
                             ty_iteration.ndim - 1,
                             shape=ty_iteration.shape[1:])
        else:
            ty_i_ = TyTensor(ty_iteration.dtype, ty_iteration.kind,
                             ty_iteration.ndim - 1)
        unify(ty_i, ty_i_)
    else:
        # Any other iterable must unify with a sequence of ty_i.
        unify(TySequence(ty_i, None), ty_iteration)

    tc.infer_expr(node.elt)

    utils.add_dict(self.nodetype, tc.nodetype)
    utils.add_dict(self.subroutine_node, tc.subroutine_node)

    self.nodetype[node] = TyList(tc.nodetype[node.elt])
    return self.nodetype[node]
def infer_user_defined_function(self, func, ty_args, node):
    """Type-check a call into user-written Python code.

    The callee's source is retrieved with ``inspect.getsource``, passed
    through ``clip_head`` and parsed into a gast ``FunctionDef``. That node is
    stored in ``self.subroutine_node[node]`` and inferred by a fresh
    ``InferenceEngine`` bound to the callee's defining module. Type hints come
    from ``typing.get_type_hints``. The sub-engine's ``nodetype`` and
    ``subroutine_node`` are then merged into ``self``.

    Args:
        func: the called object. Either a plain function/method, or a
            callable object: a ``chainer.Chain`` is inferred through its
            ``forward``, anything else through ``__call__``.
        ty_args: the argument types of the call.
        node: the call node. ``node.func`` is inspected for attribute-style
            calls.

    Returns:
        ``(ty_args, ty_ret)``. ``ty_args`` may have a receiver type
        prepended. ``ty_ret`` is the type recorded for the callee's
        ``FunctionDef`` node.
    """
    if isinstance(func, (types.FunctionType, types.MethodType)):
        func_body = func
        # NOTE(review): this tests only the syntactic form ``a.f(...)``.
        # Module-qualified calls such as ``mod.f(...)`` also get ``a``'s
        # type prepended. Confirm that is intended.
        if isinstance(node.func, gast.Attribute):
            ty_args = [self.nodetype[node.func.value]] + ty_args
    else:
        # A callable object: its own type is passed as the ``self`` argument.
        func_body = func.forward if isinstance(func, chainer.Chain) else func.__call__
        ty_args = [type_of_value(func)] + ty_args

    # clip_head presumably strips leading indentation so that a method's
    # source parses at top level. TODO confirm.
    source = clip_head(inspect.getsource(func_body))
    func_node = gast.ast_to_gast(ast.parse(source)).body[0]  # the FunctionDef
    self.subroutine_node[node] = func_node

    callee_module = sys.modules[func.__module__]
    sub_engine = InferenceEngine(is_debug=self.is_debug, module=callee_module)
    sub_engine.infer_function(func_node, ty_args,
                              type_hints=typing.get_type_hints(func_body))

    # Pull the callee's per-node results back into this engine.
    utils.add_dict(self.nodetype, sub_engine.nodetype)
    utils.add_dict(self.subroutine_node, sub_engine.subroutine_node)

    return ty_args, sub_engine.nodetype[func_node]
def infer_For(self, node):
    """Infer types for a ``for`` statement.

    The target's type is unified with the element type of the iterated
    value, and then the loop body is inferred. ``node.orelse`` is not
    visited here.
    """
    # For(expr target, expr iter, stmt* body, stmt* orelse)
    assert isinstance(node.target, (gast.Name, gast.Tuple))

    ty_iteration = self.infer_expr(node.iter)
    ty_i = self.infer_expr(node.target)

    if isinstance(ty_iteration, TyTensor):
        # Element = tensor with the leading dimension dropped.
        # NOTE(review): ``shape[1:]`` raises if shape can be None, and the
        # argument order (kind, dtype, shape) differs from other TyTensor
        # call sites that pass (dtype, kind, ndim, shape=...). Confirm both
        # against TyTensor's signature.
        ty_elem = TyTensor(ty_iteration.kind, ty_iteration.dtype,
                           ty_iteration.shape[1:])
        unify(ty_i, ty_elem)
    elif isinstance(ty_iteration, TyList):
        unify(ty_iteration, TyList(ty_i))
    else:
        unify(ty_iteration, TyTuple(ty_i))

    # The body is inferred twice, each time from a fresh copy of this
    # engine, and the per-node results are merged back after each pass.
    # Presumably the second pass lets types produced late in the body
    # (loop-carried variables) reach earlier statements. TODO confirm.
    for _pass in range(2):
        body_tc = copy_InferenceEngine(self)
        self.infer_block(body_tc, node.body)
        utils.add_dict(self.nodetype, body_tc.nodetype)
        utils.add_dict(self.subroutine_node, body_tc.subroutine_node)
def infer_If(self, node):
    """Infer types for an ``if`` statement (``lazy_initializer`` variant).

    Each branch is inferred in its own copy of the engine, and the per-node
    results are merged back. If ``lazy_initializer(node)`` returns an
    expression, that expression's type is marked non-optional afterwards.

    NOTE(review): a second ``infer_If`` is defined later in this file. If
    both sit in the same class, this definition is shadowed and never runs.
    Confirm and remove one.
    """
    # If(expr test, stmt* body, stmt* orelse)
    # XXX: type of node.test can be anything
    self.infer_expr(node.test)
    # Presumably the variable this ``if`` initializes when it is None. Verify.
    x = lazy_initializer(node)

    if node.orelse != []:
        then_tc = copy_InferenceEngine(self)
        else_tc = copy_InferenceEngine(self)
        self.infer_2blocks(then_tc, else_tc, node.body, node.orelse)
        # Merge order is kept: both nodetypes first, then both subroutine maps.
        for branch in (then_tc, else_tc):
            utils.add_dict(self.nodetype, branch.nodetype)
        for branch in (then_tc, else_tc):
            utils.add_dict(self.subroutine_node, branch.subroutine_node)
    else:
        body_tc = copy_InferenceEngine(self)
        self.infer_block(body_tc, node.body)
        utils.add_dict(self.nodetype, body_tc.nodetype)
        utils.add_dict(self.subroutine_node, body_tc.subroutine_node)

    if x is not None:
        self.infer_expr(x).is_optional = False
def infer_If(self, node):
    """Infer types for an ``if`` statement, narrowing on ``is None`` tests.

    When ``is_isNone(node.test)`` returns an expression, ``split_optional``
    is applied to the two continuations before they are inferred:
    body vs. else-branch, or body vs. the fall-through ``self`` when there is
    no ``else``. Each branch runs in a copy of the engine, and its per-node
    results are merged back.

    NOTE(review): an earlier ``infer_If`` also exists in this file. If both
    are in the same class, this later one is the definition that takes effect.

    Returns:
        The value returned by ``infer_block`` / ``infer_2blocks``.
    """
    # If(expr test, stmt* body, stmt* orelse)
    # XXX: type of node.test can be anything
    self.infer_expr(node.test)
    x = is_isNone(node.test)

    if node.orelse != []:
        then_tc = copy_InferenceEngine(self)
        else_tc = copy_InferenceEngine(self)
        if x is not None:
            self.split_optional(then_tc, else_tc, x)
        ty_ret = self.infer_2blocks(then_tc, else_tc, node.body, node.orelse)
        # Merge order is kept: both nodetypes first, then both subroutine maps.
        for branch in (then_tc, else_tc):
            utils.add_dict(self.nodetype, branch.nodetype)
        for branch in (then_tc, else_tc):
            utils.add_dict(self.subroutine_node, branch.subroutine_node)
        return ty_ret

    body_tc = copy_InferenceEngine(self)
    if x is not None:
        # No else-branch: the "other side" of the split is self, i.e. the
        # code that falls through when the body is not taken.
        self.split_optional(body_tc, self, x)
    ty_ret = self.infer_block(body_tc, node.body)
    utils.add_dict(self.nodetype, body_tc.nodetype)
    utils.add_dict(self.subroutine_node, body_tc.subroutine_node)
    return ty_ret