Пример #1
0
def find_matches(overloads, argtypes, constraints=()):
    """Yield an ``Overload`` for each candidate that unifies with the inputs.

    :param overloads: iterable of ``(func, sig, kwds)`` triples; every
        ``sig`` must be a ``T.Function``.
    :param argtypes: the argument types of the call. Any iterable is
        accepted (list, tuple, generator, ...).
    :param constraints: optional iterable of extra equations that are
        unified together with the call-site equation for every overload.
    :returns: a generator of ``Overload(dst_sig, sig, func, remaining, kwds)``
        objects, one per overload for which unification succeeded.
    :raises TypeError: if an overload takes a different number of arguments
        than were supplied.

    This is a generator, so nothing (including the arity check) runs until
    the result is iterated.
    """
    # Materialise both inputs once: tuples cannot be concatenated with a
    # list, and a one-shot iterable would be empty after the first overload.
    arg_list = list(argtypes)
    extra_equations = list(constraints)
    n_given = len(arg_list)

    # The call site expressed as a function type whose return type is the
    # type variable R.
    call_sig = T.Function(*(arg_list + [T.TypeVar('R')]))

    for func, sig, kwds in overloads:
        assert isinstance(sig, T.Function), sig

        # -------------------------------------------------
        # Error checking
        n_expected = len(sig.argtypes)
        if n_expected != n_given:
            raise TypeError(
                "Expected %d args, got %d for function %s" % (
                    n_expected, n_given, func))

        # -------------------------------------------------
        # Unification

        # The call-site equation comes first so that result[0] below is the
        # unified signature for this overload.
        equations = [(call_sig, sig)] + extra_equations
        # NOTE(review): one flag per *argument*, not per equation -- kept
        # exactly as in the original; confirm against unify()'s contract.
        broadcasting = [True] * n_expected

        try:
            result, remaining = unify(equations, broadcasting)
        except error.UnificationError:
            # This overload does not match the inputs; try the next one.
            continue
        else:
            dst_sig = result[0]
            yield Overload(dst_sig, sig, func, remaining, kwds)
Пример #2
0
    def test_unify_ellipsis_to_scalar(self):
        """The ``A...`` solved in the argument must propagate to the result."""
        variadic = dshape('A..., int32 -> A..., int32')
        scalar = dshape('int32 -> R')

        # Ellipsis signature on the left: no constraints are left over.
        solved, leftover = unify([(variadic, scalar)], [True])
        [signature] = solved
        self.assertEqual(str(signature), 'int32 -> int32')
        self.assertEqual(leftover, [])

        # Operands swapped: same signature, but one constraint remains,
        # namely that R must be coercible to int32.
        solved, leftover = unify([(scalar, variadic)], [True])
        [signature] = solved
        self.assertEqual(str(signature), 'int32 -> int32')
        self.assertEqual(len(leftover), 1)
Пример #3
0
    def test_unify_ellipsis_broadcast(self):
        """Broadcasting ``A...`` must not prepend "1, " dimensions."""
        variadic = dshape('A..., int32 -> A..., int32 -> A..., int32')
        concrete = dshape('3, int32 -> int32 -> R')
        expected = '3, int32 -> int32 -> 3, int32'

        # Ellipsis signature on the left: no constraints are left over.
        solved, leftover = unify([(variadic, concrete)], [True])
        [signature] = solved
        self.assertEqual(str(signature), expected)
        self.assertEqual(leftover, [])

        # Operands swapped: same signature, with exactly one constraint
        # remaining.
        solved, leftover = unify([(concrete, variadic)], [True])
        [signature] = solved
        self.assertEqual(str(signature), expected)
        self.assertEqual(len(leftover), 1)
Пример #4
0
    def test_unify_datashape_promotion2(self):
        """Unify two equations that share type variables across both sides."""
        # Left-hand sides of the two equations
        lhs_first, lhs_second = 'A, B, int32', 'B, 10, float32'
        # Right-hand sides of the two equations
        rhs_first, rhs_second = 'X, Y, int16', 'X, X, Z'

        d1, d2, d3, d4 = dshapes(lhs_first, lhs_second, rhs_first, rhs_second)
        equations = [(d1, d3), (d2, d4)]

        # The equations tell us that:
        #   1) A coerces to X
        #   2) B coerces to Y
        #   3) 10 coerces to X
        #
        # Hence X has to be Fixed(10). Type variable B is kept for Y: all we
        # know is that B must unify with Fixed(10), not that it actually is
        # Fixed(10) (it could just as well be Fixed(1)).

        solved, _leftover = unify(equations, [True, True])
        [arg1, arg2] = solved
        self.assertEqual(str(arg1), '10, B, int16')
        self.assertEqual(str(arg2), '10, 10, float32')
Пример #5
0
def merge(contexts):
    """
    Combine several graph expression contexts into one fresh context.

    The constraints, terms and params of every input context are accumulated
    in order, after which the collected constraints are run through
    ``unify`` and replaced by its first return value.
    """
    merged = ExprContext()

    for context in contexts:
        merged.constraints.extend(context.constraints)
        merged.terms.update(context.terms)
        merged.params.extend(context.params)

    # Only the unified constraints are kept; the second value returned by
    # unify() is discarded.
    unified, _ = unify(merged.constraints)
    merged.constraints = unified
    return merged
Пример #6
0
 def test_unify_ellipsis(self):
     """Unify two datashapes that each contain a bare ``...``."""
     left = dshape('A, ..., B, int32')
     right = dshape('M, N, ..., S, T, float32')
     solved, _leftover = unify([(left, right)], [True])
     [unified] = solved
     self.assertEqual(str(unified), 'A, N, ..., S, B, float32')
Пример #7
0
 def test_unify_broadcasting2(self):
     """Unify a three-dimension shape with a two-dimension shape."""
     longer = dshape('A, B, C, int32')
     shorter = dshape('M, N, float32')
     solved, _leftover = unify([(longer, shorter)], [True])
     [unified] = solved
     self.assertEqual(str(unified), '1, B, C, float32')
Пример #8
0
 def test_unify_datashape_promotion(self):
     """A fixed dimension resolves a type variable that is used twice."""
     partly_fixed = dshape('10, T1, int32')
     repeated_var = dshape('T2, T2, float32')
     solved, _leftover = unify([(partly_fixed, repeated_var)], [True])
     [unified] = solved
     self.assertEqual(unified, dshape('10, 10, float32'))
Пример #9
0
 def test_unify_implements(self):
     """Unify a concrete shape with one whose dtype is ``A : numeric``."""
     concrete = dshape('10, int32')
     constrained = dshape('T, A : numeric')
     solved, leftover = unify([(concrete, constrained)], [True])
     [unified] = solved
     self.assertEqual(str(unified), '10, int32')
     # No constraints may remain after unification.
     self.assertFalse(leftover)
Пример #10
0
 def test_unify_ellipsis2(self):
     """Unify function signatures where one argument uses a bare ``...``."""
     with_ellipsis = dshape('X, Y, float32 -> ..., float32 -> Z')
     explicit = dshape('10, T1, int32 -> T2, T2, float32 -> R')
     solved, _leftover = unify([(with_ellipsis, explicit)], [True])
     [unified] = solved
     self.assertEqual(str(unified), '10, Y, int32 -> T2, T2, float32 -> Z')
Пример #11
0
 def test_unify_broadcasting1(self):
     """Unify a two-dimension shape with a three-dimension shape."""
     shorter = dshape("A, B, int32")
     longer = dshape("K, M, N, float32")
     solved, _leftover = unify([(shorter, longer)], [True])
     [unified] = solved
     self.assertEqual(str(unified), "1, A, B, float32")