Ejemplo n.º 1
0
  def testInstanceToType(self):
    """Verifies instance_to_type() against a table of (hint, value) pairs."""
    class MyClass(object):
      def method(self):
        pass

    any_hint = typehints.Any
    str_or_float = typehints.Union[str, float]

    # Each entry pairs the expected type hint with the value to inspect.
    expectations = [
        # Dictionaries.
        (typehints.Dict[str, int], {'a': 1}),
        (typehints.Dict[str, typehints.Union[str, int]], {'a': 1, 'b': 'c'}),
        (typehints.Dict[any_hint, any_hint], {}),
        # Sets and frozensets.
        (typehints.Set[str], {'a'}),
        (typehints.Set[str_or_float], {'a', 0.4}),
        (typehints.Set[any_hint], set()),
        (typehints.FrozenSet[str], frozenset(['a'])),
        (typehints.FrozenSet[str_or_float], frozenset(['a', 0.4])),
        (typehints.FrozenSet[any_hint], frozenset()),
        # Tuples.
        (typehints.Tuple[int], (1,)),
        (typehints.Tuple[int, int, str], (1, 2, '3')),
        (typehints.Tuple[()], ()),
        # Lists.
        (typehints.List[int], [1]),
        (typehints.List[typehints.Union[int, str]], [1, 'a']),
        (typehints.List[any_hint], []),
        # None, a class, an instance and methods looked up both ways.
        (type(None), None),
        (type(MyClass), MyClass),
        (MyClass, MyClass()),
        (type(MyClass.method), MyClass.method),
        (types.MethodType, MyClass().method),
        # A Row instance.
        (row_type.RowTypeConstraint([('x', int)]), beam.Row(x=37)),
    ]
    for hint, value in expectations:
      inferred = trivial_inference.instance_to_type(value)
      self.assertEqual(hint, inferred, msg=value)
Ejemplo n.º 2
0
def infer_return_type_func(f, input_types, debug=False, depth=0):
    """Analyses a function to deduce its return type.

    The function's bytecode is walked one instruction at a time while a
    FrameState is updated. States reaching a jump target are merged with
    ``|``; a backward jump is re-followed at most five times per jump site,
    and only while the merged state at its target keeps changing.

    Args:
      f: A Python function object to infer the return type of.
      input_types: A sequence of inputs corresponding to the input types.
      debug: Whether to print verbose debugging information.
      depth: Maximum inspection depth during type inference. Calls found in
        the bytecode are only analysed recursively while depth > 0; otherwise
        their result is treated as Any.

    Returns:
      A TypeConstraint that the return value of this function will (likely)
      satisfy given the specified inputs. If the function yields, the result
      is typehints.Iterable of the union of the yielded types.

    Raises:
      TypeInferenceError: if no type can be inferred, e.g. an opcode is
        encountered that this function cannot handle.
    """
    if debug:
        print()
        print(f, id(f), input_types)
        dis.dis(f)
    from . import opcodes
    # Handlers for "simple" opcodes are looked up by the upper-cased name of
    # each attribute of the opcodes module.
    simple_ops = {k.upper(): v for k, v in opcodes.__dict__.items()}

    co = f.__code__
    code = co.co_code
    end = len(code)
    pc = 0
    extended_arg = 0  # Python 2 only.
    free = None
    # Pre-bind arg so that an argument-less first opcode handed to a simple op
    # handler cannot trigger an UnboundLocalError.
    arg = None

    yields = set()
    returns = set()
    # TODO(robertwb): Default args via inspect module.
    # Locals beyond the supplied inputs start out as the empty Union.
    local_vars = list(input_types) + [typehints.Union[
        ()]] * (len(co.co_varnames) - len(input_types))
    state = FrameState(f, local_vars)
    # Merged state recorded for each jump target offset (None until reached).
    states = collections.defaultdict(lambda: None)
    # How many times each backward jump has been re-followed.
    jumps = collections.defaultdict(int)

    # In Python 3, use dis library functions to disassemble bytecode and handle
    # EXTENDED_ARGs.
    is_py3 = sys.version_info[0] == 3
    if is_py3:
        # offset -> instruction
        ofs_table = {ins.offset: ins for ins in dis.get_instructions(f)}

    # Python 2 - 3.5: 1 byte opcode + optional 2 byte arg (1 or 3 bytes).
    # Python 3.6+: 1 byte opcode + 1 byte arg (2 bytes, arg may be ignored).
    if sys.version_info >= (3, 6):
        inst_size = 2
        opt_arg_size = 0
    else:
        inst_size = 1
        opt_arg_size = 2

    # NOTE: last_pc is never updated below, so the '-->' debug marker is never
    # printed.
    last_pc = -1
    while pc < end:  # pylint: disable=too-many-nested-blocks
        start = pc
        if is_py3:
            instruction = ofs_table[pc]
            op = instruction.opcode
        else:
            op = ord(code[pc])
        if debug:
            print('-->' if pc == last_pc else '    ', end=' ')
            print(repr(pc).rjust(4), end=' ')
            print(dis.opname[op].ljust(20), end=' ')

        pc += inst_size
        if op >= dis.HAVE_ARGUMENT:
            if is_py3:
                arg = instruction.arg
            else:
                arg = ord(code[pc]) + ord(code[pc + 1]) * 256 + extended_arg
            extended_arg = 0
            pc += opt_arg_size
            if op == dis.EXTENDED_ARG:
                extended_arg = arg * 65536
            if debug:
                print(str(arg).rjust(5), end=' ')
                if op in dis.hasconst:
                    print('(' + repr(co.co_consts[arg]) + ')', end=' ')
                elif op in dis.hasname:
                    print('(' + co.co_names[arg] + ')', end=' ')
                elif op in dis.hasjrel:
                    print('(to ' + repr(pc + arg) + ')', end=' ')
                elif op in dis.haslocal:
                    print('(' + co.co_varnames[arg] + ')', end=' ')
                elif op in dis.hascompare:
                    print('(' + dis.cmp_op[arg] + ')', end=' ')
                elif op in dis.hasfree:
                    if free is None:
                        free = co.co_cellvars + co.co_freevars
                    print('(' + free[arg] + ')', end=' ')

        # Actually emulate the op.
        if state is None and states[start] is None:
            # No control reaches here (yet).
            if debug:
                print()
            continue
        # Merge in whatever state earlier jumps recorded for this offset.
        state |= states[start]

        opname = dis.opname[op]
        jmp = jmp_state = None
        if opname.startswith('CALL_FUNCTION'):
            if sys.version_info < (3, 6):
                # Each keyword takes up two arguments on the stack (name and value).
                standard_args = (arg & 0xFF) + 2 * (arg >> 8)
                var_args = 'VAR' in opname
                kw_args = 'KW' in opname
                pop_count = standard_args + var_args + kw_args + 1
                if depth <= 0:
                    return_type = Any
                elif arg >> 8:
                    # Any keyword call that is not recognised as Row(...) below
                    # yields Any. Assigning the default up front guarantees
                    # return_type is bound even when the callee is not a Const.
                    # TODO(robertwb): Handle the remaining keyword-call cases.
                    return_type = Any
                    if not var_args and not kw_args and not arg & 0xFF:
                        # Keywords only, maybe it's a call to Row.
                        if isinstance(state.stack[-pop_count], Const):
                            from apache_beam.pvalue import Row
                            if state.stack[-pop_count].value == Row:
                                # Names and values alternate on the stack.
                                fields = state.stack[-pop_count + 1::2]
                                field_types = state.stack[-pop_count + 2::2]
                                return_type = row_type.RowTypeConstraint(
                                    zip([fld.value for fld in fields],
                                        Const.unwrap_all(field_types)))
                elif isinstance(state.stack[-pop_count], Const):
                    # TODO(robertwb): Handle this better.
                    if var_args or kw_args:
                        state.stack[-1] = Any
                        state.stack[-var_args - kw_args] = Any
                    return_type = infer_return_type(
                        state.stack[-pop_count].value,
                        state.stack[1 - pop_count:],
                        debug=debug,
                        depth=depth - 1)
                else:
                    return_type = Any
                # Replace the callee and its arguments with the call's result.
                state.stack[-pop_count:] = [return_type]
            else:  # Python 3.6+
                if opname == 'CALL_FUNCTION':
                    pop_count = arg + 1
                    if depth <= 0:
                        return_type = Any
                    elif isinstance(state.stack[-pop_count], Const):
                        return_type = infer_return_type(
                            state.stack[-pop_count].value,
                            state.stack[1 - pop_count:],
                            debug=debug,
                            depth=depth - 1)
                    else:
                        return_type = Any
                elif opname == 'CALL_FUNCTION_KW':
                    # TODO(udim): Handle keyword arguments. Requires passing them by name
                    #   to infer_return_type.
                    pop_count = arg + 2
                    if isinstance(state.stack[-pop_count], Const):
                        from apache_beam.pvalue import Row
                        if state.stack[-pop_count].value == Row:
                            # Top of stack holds the keyword names; the values
                            # sit between the callee and that entry.
                            fields = state.stack[-1].value
                            return_type = row_type.RowTypeConstraint(
                                zip(
                                    fields,
                                    Const.unwrap_all(state.stack[-pop_count +
                                                                 1:-1])))
                        else:
                            return_type = Any
                    else:
                        return_type = Any
                elif opname == 'CALL_FUNCTION_EX':
                    # stack[-has_kwargs]: Map of keyword args.
                    # stack[-1 - has_kwargs]: Iterable of positional args.
                    # stack[-2 - has_kwargs]: Function to call.
                    has_kwargs = arg & 1  # type: int
                    pop_count = has_kwargs + 2
                    if has_kwargs:
                        # TODO(udim): Unimplemented. Requires same functionality as a
                        #   CALL_FUNCTION_KW implementation.
                        return_type = Any
                    else:
                        args = state.stack[-1]
                        _callable = state.stack[-2]
                        if depth <= 0 or not isinstance(_callable, Const):
                            # Honour the depth limit like the other call
                            # opcodes, and do not read .value from a callee
                            # that is not a known constant.
                            return_type = Any
                        else:
                            if isinstance(args, typehints.ListConstraint):
                                # Case where there's a single var_arg argument.
                                args = [args]
                            elif isinstance(args, typehints.TupleConstraint):
                                args = list(args._inner_types())
                            return_type = infer_return_type(_callable.value,
                                                            args,
                                                            debug=debug,
                                                            depth=depth - 1)
                else:
                    raise TypeInferenceError('unable to handle %s' % opname)
                # Replace the callee and its arguments with the call's result.
                state.stack[-pop_count:] = [return_type]
        elif opname == 'CALL_METHOD':
            pop_count = 1 + arg
            # LOAD_METHOD will return a non-Const (Any) if loading from an Any.
            if isinstance(state.stack[-pop_count], Const) and depth > 0:
                return_type = infer_return_type(state.stack[-pop_count].value,
                                                state.stack[1 - pop_count:],
                                                debug=debug,
                                                depth=depth - 1)
            else:
                return_type = typehints.Any
            state.stack[-pop_count:] = [return_type]
        elif opname in simple_ops:
            if debug:
                print("Executing simple op " + opname)
            simple_ops[opname](state, arg)
        elif opname == 'RETURN_VALUE':
            returns.add(state.stack[-1])
            # Control does not fall through a return.
            state = None
        elif opname == 'YIELD_VALUE':
            yields.add(state.stack[-1])
        elif opname == 'JUMP_FORWARD':
            jmp = pc + arg
            jmp_state = state
            state = None
        elif opname == 'JUMP_ABSOLUTE':
            jmp = arg
            jmp_state = state
            state = None
        elif opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE'):
            # The tested value is popped on both paths.
            state.stack.pop()
            jmp = arg
            jmp_state = state.copy()
        elif opname in ('JUMP_IF_TRUE_OR_POP', 'JUMP_IF_FALSE_OR_POP'):
            # The value stays on the stack for the jump path and is popped
            # only on the fall-through path.
            jmp = arg
            jmp_state = state.copy()
            state.stack.pop()
        elif opname == 'FOR_ITER':
            # Jump path: the iterator is popped. Fall-through path: the
            # iterator's element type is pushed on top of it.
            jmp = pc + arg
            jmp_state = state.copy()
            jmp_state.stack.pop()
            state.stack.append(element_type(state.stack[-1]))
        else:
            raise TypeInferenceError('unable to handle %s' % opname)

        if jmp is not None:
            # TODO(robertwb): Is this guaranteed to converge?
            new_state = states[jmp] | jmp_state
            # Re-follow a backward jump only while it still changes the state
            # at its target, and at most five times per jump site.
            if jmp < pc and new_state != states[jmp] and jumps[pc] < 5:
                jumps[pc] += 1
                pc = jmp
            states[jmp] = new_state

        if debug:
            print()
            print(state)
            pprint.pprint(dict(item for item in states.items() if item[1]))

    if yields:
        result = typehints.Iterable[reduce(union, Const.unwrap_all(yields))]
    else:
        result = reduce(union, Const.unwrap_all(returns))
    finalize_hints(result)

    if debug:
        print(f, id(f), input_types, '->', result)
    return result
 def testRowAttr(self):
     """Checks inference of Row field access via attribute and getattr()."""
     # Input hint: a row with an int field 'x' and a str field 'y'.
     row_hint = row_type.RowTypeConstraint([('x', int), ('y', str)])
     expected = typehints.Tuple[int, str]
     self.assertReturnType(
         expected, lambda row: (row.x, getattr(row, 'y')), [row_hint])