def testInstanceToType(self):
  """Checks instance_to_type against a table of (expected type, instance)."""
  class MyClass(object):
    def method(self):
      pass

  # Cases are grouped by the kind of instance; the groups are concatenated
  # in the order they are checked.
  dict_cases = [
      (typehints.Dict[str, int], {'a': 1}),
      (typehints.Dict[str, typehints.Union[str, int]], {'a': 1, 'b': 'c'}),
      (typehints.Dict[typehints.Any, typehints.Any], {}),
  ]
  set_cases = [
      (typehints.Set[str], {'a'}),
      (typehints.Set[typehints.Union[str, float]], {'a', 0.4}),
      (typehints.Set[typehints.Any], set()),
  ]
  frozenset_cases = [
      (typehints.FrozenSet[str], frozenset(['a'])),
      (
          typehints.FrozenSet[typehints.Union[str, float]],
          frozenset(['a', 0.4])),
      (typehints.FrozenSet[typehints.Any], frozenset()),
  ]
  tuple_cases = [
      (typehints.Tuple[int], (1, )),
      (typehints.Tuple[int, int, str], (1, 2, '3')),
      (typehints.Tuple[()], ()),
  ]
  list_cases = [
      (typehints.List[int], [1]),
      (typehints.List[typehints.Union[int, str]], [1, 'a']),
      (typehints.List[typehints.Any], []),
  ]
  object_cases = [
      (type(None), None),
      (type(MyClass), MyClass),
      (MyClass, MyClass()),
      (type(MyClass.method), MyClass.method),
      (types.MethodType, MyClass().method),
      (row_type.RowTypeConstraint([('x', int)]), beam.Row(x=37)),
  ]
  all_cases = (
      dict_cases + set_cases + frozenset_cases + tuple_cases + list_cases +
      object_cases)

  for wanted, value in all_cases:
    inferred = trivial_inference.instance_to_type(value)
    # The instance itself is used as the failure message.
    self.assertEqual(wanted, inferred, msg=value)
def infer_return_type_func(f, input_types, debug=False, depth=0):
  """Analyses a function to deduce its return type.

  Walks the bytecode of ``f`` instruction by instruction, emulating each
  opcode on a FrameState that tracks the types of the stack entries and the
  local variables. States reaching a jump target are merged with ``|``.

  Args:
    f: A Python function object to infer the return type of.
    input_types: A sequence of inputs corresponding to the input types.
    debug: Whether to print verbose debugging information.
    depth: Maximum inspection depth during type inference. Calls made by ``f``
      are only followed while ``depth`` is positive; otherwise the result of
      the call is taken to be Any.

  Returns:
    A TypeConstraint that the return value of this function will (likely)
    satisfy given the specified inputs. If any YIELD_VALUE was seen this is an
    Iterable of the union of the yielded types, otherwise it is the union of
    the returned types.

  Raises:
    TypeInferenceError: if no type can be inferred.
  """
  if debug:
    print()
    print(f, id(f), input_types)
    dis.dis(f)
  from . import opcodes
  # Handlers for opcodes that only manipulate the frame state, keyed by the
  # upper-cased name of each attribute of the opcodes module.
  simple_ops = dict((k.upper(), v) for k, v in opcodes.__dict__.items())

  co = f.__code__
  code = co.co_code
  end = len(code)
  pc = 0
  extended_arg = 0  # Python 2 only.
  free = None

  yields = set()
  returns = set()
  # TODO(robertwb): Default args via inspect module.
  # Locals not covered by input_types start out as the empty Union.
  local_vars = list(input_types) + [typehints.Union[()]] * (
      len(co.co_varnames) - len(input_types))
  state = FrameState(f, local_vars)
  # Merged state known to reach each bytecode offset (None if none yet).
  states = collections.defaultdict(lambda: None)
  # Number of times each backward jump has been re-taken.
  jumps = collections.defaultdict(int)

  # In Python 3, use dis library functions to disassemble bytecode and handle
  # EXTENDED_ARGs.
  is_py3 = sys.version_info[0] == 3
  if is_py3:
    ofs_table = {}  # offset -> instruction
    for instruction in dis.get_instructions(f):
      ofs_table[instruction.offset] = instruction

  # Python 2 - 3.5: 1 byte opcode + optional 2 byte arg (1 or 3 bytes).
  # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
  if sys.version_info >= (3, 6):
    inst_size = 2
    opt_arg_size = 0
  else:
    inst_size = 1
    opt_arg_size = 2

  last_pc = -1
  while pc < end:  # pylint: disable=too-many-nested-blocks
    start = pc
    if is_py3:
      instruction = ofs_table[pc]
      op = instruction.opcode
    else:
      op = ord(code[pc])
    if debug:
      print('-->' if pc == last_pc else ' ', end=' ')
      print(repr(pc).rjust(4), end=' ')
      print(dis.opname[op].ljust(20), end=' ')

    pc += inst_size
    if op >= dis.HAVE_ARGUMENT:
      if is_py3:
        arg = instruction.arg
      else:
        arg = ord(code[pc]) + ord(code[pc + 1]) * 256 + extended_arg
      extended_arg = 0
      pc += opt_arg_size
      if op == dis.EXTENDED_ARG:
        extended_arg = arg * 65536
      if debug:
        print(str(arg).rjust(5), end=' ')
        if op in dis.hasconst:
          print('(' + repr(co.co_consts[arg]) + ')', end=' ')
        elif op in dis.hasname:
          print('(' + co.co_names[arg] + ')', end=' ')
        elif op in dis.hasjrel:
          print('(to ' + repr(pc + arg) + ')', end=' ')
        elif op in dis.haslocal:
          print('(' + co.co_varnames[arg] + ')', end=' ')
        elif op in dis.hascompare:
          print('(' + dis.cmp_op[arg] + ')', end=' ')
        elif op in dis.hasfree:
          if free is None:
            free = co.co_cellvars + co.co_freevars
          print('(' + free[arg] + ')', end=' ')

    # Actually emulate the op.
    if state is None and states[start] is None:
      # No control reaches here (yet).
      if debug:
        print()
      continue
    # Merge the fall-through state with whatever jumped to this offset.
    state |= states[start]

    opname = dis.opname[op]
    jmp = jmp_state = None
    if opname.startswith('CALL_FUNCTION'):
      if sys.version_info < (3, 6):
        # Each keyword takes up two arguments on the stack (name and value).
        standard_args = (arg & 0xFF) + 2 * (arg >> 8)
        var_args = 'VAR' in opname
        kw_args = 'KW' in opname
        pop_count = standard_args + var_args + kw_args + 1
        if depth <= 0:
          return_type = Any
        elif arg >> 8:
          # Calls with keyword arguments are only understood when they are a
          # keywords-only call to Row. Default to Any so that every path
          # below leaves return_type assigned for this instruction.
          # TODO(robertwb): Handle this case.
          return_type = Any
          if not var_args and not kw_args and not arg & 0xFF:
            # Keywords only, maybe it's a call to Row.
            if isinstance(state.stack[-pop_count], Const):
              from apache_beam.pvalue import Row
              if state.stack[-pop_count].value == Row:
                # Names and values alternate on the stack after the callee.
                fields = state.stack[-pop_count + 1::2]
                field_types = state.stack[-pop_count + 2::2]
                return_type = row_type.RowTypeConstraint(
                    zip([fld.value for fld in fields],
                        Const.unwrap_all(field_types)))
        elif isinstance(state.stack[-pop_count], Const):
          # TODO(robertwb): Handle this better.
          if var_args or kw_args:
            state.stack[-1] = Any
            state.stack[-var_args - kw_args] = Any
          return_type = infer_return_type(
              state.stack[-pop_count].value,
              state.stack[1 - pop_count:],
              debug=debug,
              depth=depth - 1)
        else:
          return_type = Any
        state.stack[-pop_count:] = [return_type]
      else:  # Python 3.6+
        if opname == 'CALL_FUNCTION':
          pop_count = arg + 1
          if depth <= 0:
            return_type = Any
          elif isinstance(state.stack[-pop_count], Const):
            return_type = infer_return_type(
                state.stack[-pop_count].value,
                state.stack[1 - pop_count:],
                debug=debug,
                depth=depth - 1)
          else:
            return_type = Any
        elif opname == 'CALL_FUNCTION_KW':
          # TODO(udim): Handle keyword arguments. Requires passing them by name
          # to infer_return_type.
          pop_count = arg + 2
          if isinstance(state.stack[-pop_count], Const):
            from apache_beam.pvalue import Row
            if state.stack[-pop_count].value == Row:
              # The top of the stack holds the keyword names; the values sit
              # between the callee and that entry.
              fields = state.stack[-1].value
              return_type = row_type.RowTypeConstraint(
                  zip(
                      fields,
                      Const.unwrap_all(state.stack[-pop_count + 1:-1])))
            else:
              return_type = Any
          else:
            return_type = Any
        elif opname == 'CALL_FUNCTION_EX':
          # stack[-has_kwargs]: Map of keyword args.
          # stack[-1 - has_kwargs]: Iterable of positional args.
          # stack[-2 - has_kwargs]: Function to call.
          has_kwargs = arg & 1  # type: int
          pop_count = has_kwargs + 2
          if has_kwargs:
            # TODO(udim): Unimplemented. Requires same functionality as a
            # CALL_FUNCTION_KW implementation.
            return_type = Any
          else:
            args = state.stack[-1]
            _callable = state.stack[-2]
            if depth <= 0 or not isinstance(_callable, Const):
              # Same policy as CALL_FUNCTION and CALL_METHOD: only follow
              # calls to a known (Const) callee while depth remains. This
              # also avoids reading .value from a non-Const stack entry.
              return_type = Any
            else:
              if isinstance(args, typehints.ListConstraint):
                # Case where there's a single var_arg argument.
                args = [args]
              elif isinstance(args, typehints.TupleConstraint):
                args = list(args._inner_types())
              return_type = infer_return_type(
                  _callable.value, args, debug=debug, depth=depth - 1)
        else:
          raise TypeInferenceError('unable to handle %s' % opname)
        state.stack[-pop_count:] = [return_type]
    elif opname == 'CALL_METHOD':
      pop_count = 1 + arg
      # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
      if isinstance(state.stack[-pop_count], Const) and depth > 0:
        return_type = infer_return_type(
            state.stack[-pop_count].value,
            state.stack[1 - pop_count:],
            debug=debug,
            depth=depth - 1)
      else:
        return_type = typehints.Any
      state.stack[-pop_count:] = [return_type]
    elif opname in simple_ops:
      if debug:
        print("Executing simple op " + opname)
      simple_ops[opname](state, arg)
    elif opname == 'RETURN_VALUE':
      returns.add(state.stack[-1])
      # Nothing falls through a return.
      state = None
    elif opname == 'YIELD_VALUE':
      yields.add(state.stack[-1])
    elif opname == 'JUMP_FORWARD':
      jmp = pc + arg
      jmp_state = state
      state = None
    elif opname == 'JUMP_ABSOLUTE':
      jmp = arg
      jmp_state = state
      state = None
    elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
      # The condition is popped on both the jump and fall-through paths.
      state.stack.pop()
      jmp = arg
      jmp_state = state.copy()
    elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
      # The condition stays on the stack for the jump path only.
      jmp = arg
      jmp_state = state.copy()
      state.stack.pop()
    elif opname == 'FOR_ITER':
      # Exhausted path: jump with the iterator popped. Loop path: push the
      # element type of the iterator.
      jmp = pc + arg
      jmp_state = state.copy()
      jmp_state.stack.pop()
      state.stack.append(element_type(state.stack[-1]))
    else:
      raise TypeInferenceError('unable to handle %s' % opname)

    if jmp is not None:
      # TODO(robertwb): Is this guaranteed to converge?
      new_state = states[jmp] | jmp_state
      # Re-run from a backward jump target while the merged state keeps
      # changing, at most 5 times per jump site.
      if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
        jumps[pc] += 1
        pc = jmp
      states[jmp] = new_state

    if debug:
      print()
      print(state)
      pprint.pprint(dict(item for item in states.items() if item[1]))

  if yields:
    result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
  else:
    result = reduce(union, Const.unwrap_all(returns))
  finalize_hints(result)

  if debug:
    print(f, id(f), input_types, '->', result)
  return result
def testRowAttr(self):
  """Checks inference of Row fields read by attribute and by getattr."""
  row_hint = row_type.RowTypeConstraint([('x', int), ('y', str)])
  expected = typehints.Tuple[int, str]
  # The lambda is kept exactly as is: its bytecode is what gets analysed.
  self.assertReturnType(
      expected, lambda row: (row.x, getattr(row, 'y')), [row_hint])