예제 #1
0
 def testJoinNothingType(self):
     """JoinTypes() drops 'nothing' next to real types, else keeps one copy."""
     nothing = pytd.NothingType()
     named = pytd.NamedType("a")
     # Mixed with a real type, 'nothing' disappears from the join.
     self.assertEquals(utils.JoinTypes([named, nothing]), named)
     # On its own (once or repeated), 'nothing' collapses to a single 'nothing'.
     for only_nothing in ([nothing], [nothing, nothing]):
         self.assertEquals(utils.JoinTypes(only_nothing), nothing)
예제 #2
0
    def VisitUnionType(self, union):
        """Given a union type, try to find a simplification by using superclasses.

        This is a lossy optimization that tries to map a list of types to a
        common base type. For example, int is a base class of both int and
        bool, so it would convert "int or bool" to "int".

        Arguments:
          union: A union type. Its type_list must be non-empty; an empty
            type_list raises IndexError.

        Returns:
          utils.JoinTypes() of the classes common to the expansions
          (self._Expand) of every member, keeping only those that have no
          subclass in that common set.
        """
        # Work on a private copy: intersection_update() mutates in place, and
        # _Expand()'s result might be shared or cached by the helper.
        intersection = set(self._Expand(union.type_list[0]))

        for t in union.type_list[1:]:
            intersection.intersection_update(self._Expand(t))

        # Remove "redundant" superclasses, by removing everything from the tree
        # that's not a leaf. I.e., we don't need "object" if we have more
        # specialized types.
        # NOTE(review): iteration order of a set is arbitrary, so member order
        # in the resulting union is not deterministic.
        new_type_list = tuple(cls for cls in intersection
                              if not self._HasSubClassInSet(cls, intersection))

        return utils.JoinTypes(new_type_list)
예제 #3
0
    def VisitFunction(self, f):
        """Shrink a function by factorizing cartesian products of arguments.

        Works through argument positions from left to right. At each position
        the current signatures are grouped by self._GroupByOmittedArg(); for a
        group that carries alternative types for that position, the parameter
        at that position is replaced by one whose type is utils.JoinTypes() of
        the alternatives. This greedy scheme is *not* optimal, but it does the
        right thing for the typical cases.

        Arguments:
          f: An instance of pytd.Function. If it has more than one signature,
            some of them may be combined by introducing union types.

        Returns:
          A new, potentially optimized, instance of pytd.Function.
        """
        current = f.signatures
        widest = max(len(candidate.params) for candidate in current)

        for position in xrange(widest):
            merged = []
            groups = self._GroupByOmittedArg(current, position)
            for base_sig, alternatives in groups:
                if alternatives:
                    # One or more options for this argument: join them.
                    params = list(base_sig.params)
                    params[position] = pytd.Parameter(
                        base_sig.params[position].name,
                        utils.JoinTypes(alternatives))
                    base_sig = base_sig.Replace(params=tuple(params))
                # Without alternatives the signature lacks this argument and
                # is kept unchanged.
                merged.append(base_sig)
            current = merged

        return f.Replace(signatures=tuple(current))
예제 #4
0
 def testJoinTypes(self):
     """JoinTypes() flattens nested unions recursively, keeping member order."""
     names = [pytd.NamedType("n%d" % i) for i in xrange(6)]
     n1, n2, n3, n4, n5, n6 = names
     # Right-nested: n1 or (n2 or (n3))
     right_nested = pytd.UnionType(
         (n1, pytd.UnionType((n2, pytd.UnionType((n3,))))))
     # Left-nested: ((n4) or n5) or n6
     innermost = pytd.UnionType((n4,))
     left_nested = pytd.UnionType((pytd.UnionType((innermost, n5)), n6))
     joined = utils.JoinTypes([right_nested, left_nested])
     self.assertEquals(joined.type_list, tuple(names))
예제 #5
0
    def VisitUnionType(self, union):
        """Push unions down into containers.

        Generic types in the (flattened) union that share a base_type are
        merged into a single instance whose parameters are the element-wise
        utils.JoinTypes() of the originals' parameters, i.e. the union is
        pushed down to the element_type level. Other members are kept as-is.

        Arguments:
          union: A pytd.UnionType instance. Might appear in a parameter, a
            return type, a constant type, etc.

        Returns:
          The original union unchanged if none of its top-level members is a
          pytd.GenericType; otherwise utils.JoinTypes() of the simplified
          members, which may be a single type rather than a union.
        """
        if not any(isinstance(t, pytd.GenericType) for t in union.type_list):
            # Optimization: If we're not going to change anything, return original.
            return union

        flat = utils.JoinTypes(union.type_list)  # flatten
        # JoinTypes may collapse to a single (non-union) type.
        if isinstance(flat, pytd.UnionType):
            members = flat.type_list
        else:
            members = (flat,)

        # For each container base type: element-wise join of all parameter
        # tuples seen for it.
        # NOTE(review): zip() truncates to the shorter parameter tuple, so if two
        # generics with the same base_type differ in arity, the extra parameters
        # are dropped -- confirm this cannot occur.
        collect = {}
        for t in members:
            if isinstance(t, pytd.GenericType):
                if t.base_type in collect:
                    collect[t.base_type] = tuple(
                        utils.JoinTypes([p1, p2])
                        for p1, p2 in zip(collect[t.base_type], t.parameters))
                else:
                    collect[t.base_type] = t.parameters

        # Emit each merged generic once, at the position of its first
        # occurrence; everything else passes through unchanged.
        new_members = []
        done = set()
        for t in members:
            if isinstance(t, pytd.GenericType):
                if t.base_type in done:
                    continue  # already added
                new_members.append(t.Replace(parameters=collect[t.base_type]))
                done.add(t.base_type)
            else:
                new_members.append(t)

        # One join over all members instead of re-joining (and re-flattening)
        # the growing result once per member, which was quadratic. JoinTypes
        # still flattens, de-duplicates, drops 'nothing' and collapses '?'.
        return utils.JoinTypes(new_members)
예제 #6
0
    def VisitFunction(self, f):
        """Merge signatures of a function.

        Signatures are grouped by self._GroupByArguments(). Each group becomes
        one signature whose return type is utils.JoinTypes() ("or") of the
        group's return types and whose exceptions are the group's exceptions
        as a tuple.

        Arguments:
          f: A pytd.Function instance.

        Returns:
          A copy of f with simplified / combined signatures.
        """
        grouped = self._GroupByArguments(f.signatures)
        # NOTE(review): if _GroupByArguments returns a plain dict, the order of
        # the merged signatures follows dict iteration order -- confirm callers
        # do not depend on signature order.
        merged = [
            stripped.Replace(return_type=utils.JoinTypes(ret_exc.return_types),
                             exceptions=tuple(ret_exc.exceptions))
            for stripped, ret_exc in grouped.items()
        ]
        return f.Replace(signatures=tuple(merged))
예제 #7
0
 def testJoinSingleType(self):
     """A lone type, or the same type repeated, is returned unchanged."""
     single = pytd.NamedType("a")
     for types in ([single], [single, single]):
         self.assertEquals(utils.JoinTypes(types), single)
예제 #8
0
 def testJoinAnythingTypes(self):
     """A join containing '?' (AnythingType) collapses to '?'."""
     joined = utils.JoinTypes([pytd.AnythingType(), pytd.NamedType("a")])
     self.assertIsInstance(joined, pytd.AnythingType)
예제 #9
0
 def testJoinEmptyTypesToNothing(self):
     """Joining an empty list of types yields 'nothing' (NothingType)."""
     empty_join = utils.JoinTypes([])
     self.assertIsInstance(empty_join, pytd.NothingType)
예제 #10
0
 def VisitMutableParameter(self, p):
     """Replace p by a plain pytd.Parameter typed as p.type joined with p.new_type."""
     joined = utils.JoinTypes([p.type, p.new_type])
     return pytd.Parameter(p.name, joined)
예제 #11
0
 def VisitUnionType(self, union):
     """Rebuild the union as utils.JoinTypes() of its members."""
     members = union.type_list
     return utils.JoinTypes(members)