예제 #1
0
 def test_catches_generic_errors_with_separation(self):
     """Expect InferenceError for conflicting instances of one generic.

     Id 1 is specified as ``('List', 11)`` with 11 being ``'Int'``, while
     id 3 is specified as ``('List', 31)`` with 31 being ``'String'``.
     Ids 2 and 3 are both declared instances of 1, and ``infer()`` must
     raise.
     """
     conflicting = (
         Rules()
         .specify(1, ('List', 11))
         .specify(11, 'Int')
         .specify(3, ('List', 31))
         .specify(31, 'String')
         .instance_of(2, 1)
         .instance_of(3, 1)
     )
     with self.assertRaises(InferenceError):
         conflicting.infer()
예제 #2
0
 def test_applies_equality_constraint_from_generic(self):
     """A repeated type variable in a generic ties instance slots together.

     Id 1 is ``('Fn_2', 'var_x', 'var_x', 't_none')``; id 10 is declared an
     instance of it with slots 11, 12, 13, and only 11 is specified as
     ``'Int'``.  After inference both 11 and 12 must resolve to ``'Int'``.
     """
     generic_fn = ('Fn_2', 'var_x', 'var_x', 't_none')
     rules = (
         Rules()
         .specify(1, generic_fn)
         .specify('t_none', '()')
         .specify(10, ('Fn_2', 11, 12, 13))
         .specify(11, 'Int')
         .instance_of(10, 1)
     )
     result = rules.infer()
     # Same checks as before, in the same order: slot 11 first, then 12.
     for slot_id in (11, 12):
         self.assertEqual('Int', result.get_type_by_id(slot_id))
예제 #3
0
 def test_applies_recurisve_equality(self):
     """Equating two Pair types also equates their components.

     Checks both halves of the result: the surviving ``types`` table and
     the ``subs`` mapping produced by ``equal(1, 2)``.
     """
     rules = (
         Rules()
         .specify(1, ('Pair', 11, 12))
         .specify(2, ('Pair', 21, 22))
         .specify(11, 'Int')
         .specify(22, 'String')
         .equal(1, 2)
     )
     result = rules.infer()

     want_types = {1: ('Pair', 11, 22), 11: 'Int', 22: 'String'}
     want_subs = {2: 1, 21: 11, 12: 22}
     self.assertEqual(want_types, result.types)
     self.assertEqual(want_subs, result.subs)
예제 #4
0
 def test_applies_generics_for_multiple_levels(self):
     """Two instances of the same generic both receive its full type.

     Id 1 is ``('List', 11)`` with 11 being ``'Int'``; ids 2 and 3 are
     instances of 1 and are expected to end up as ``('List', 11)`` too,
     with no substitutions recorded.
     """
     rules = (
         Rules()
         .specify(1, ('List', 11))
         .specify(11, 'Int')
         .instance_of(2, 1)
         .instance_of(3, 1)
     )
     list_type = ('List', 11)
     expected = Result({1: list_type, 2: list_type, 3: list_type, 11: 'Int'}, {})
     self.assertEqual(expected, rules.infer())
예제 #5
0
 def test_applies_generics_recursively(self):
     """An instance relation between Pairs propagates into components.

     Id 1 is declared an instance of id 2; the test expects component 12
     to pick up ``'String'`` from component 22 while both Pair entries
     keep their own component ids and no substitutions are produced.
     """
     rules = (
         Rules()
         .specify(1, ('Pair', 11, 12))
         .specify(2, ('Pair', 21, 22))
         .specify(11, 'Int')
         .specify(22, 'String')
         .instance_of(1, 2)
     )
     expected = Result(
         {
             1: ('Pair', 11, 12),
             2: ('Pair', 21, 22),
             11: 'Int',
             12: 'String',
             22: 'String',
         },
         {},
     )
     self.assertEqual(expected, rules.infer())
예제 #6
0
 def test_allows_multiple_generic_instantiations(self):
     """One unconstrained generic may be instantiated at different types.

     Id 1 is ``('List', 11)`` with 11 left unspecified.  Ids 2 and 3 are
     Lists of ``'Int'`` and ``'String'`` respectively, both instances of 1,
     and inference must accept this and leave every entry as specified.
     """
     rules = (
         Rules()
         .specify(1, ('List', 11))
         .specify(2, ('List', 21))
         .specify(3, ('List', 31))
         .specify(21, 'Int')
         .specify(31, 'String')
         .instance_of(2, 1)
         .instance_of(3, 1)
     )
     expected = Result(
         {
             1: ('List', 11),
             2: ('List', 21),
             3: ('List', 31),
             21: 'Int',
             31: 'String',
         },
         {},
     )
     self.assertEqual(expected, rules.infer())
예제 #7
0
 def test_passthrough(self):
     """Without relations, inference returns the specified types as-is.

     The results are compared against plain ``(types, subs)`` tuples.
     """
     empty_result = Rules().infer()
     self.assertEqual(({}, {}), empty_result)

     single_result = Rules().specify(1, 'Int').infer()
     self.assertEqual(({1: 'Int'}, {}), single_result)
예제 #8
0
 def test_applies_substitution(self):
     """``equal(1, 2)`` records 2 -> 1 and keeps only id 1 in ``types``.

     The outcome must be the same whether or not id 2 was separately
     specified with the matching type ``'Int'``.
     """
     only_one_specified = Rules().specify(1, 'Int').equal(1, 2)
     self.assertEqual(Result({1: 'Int'}, {2: 1}), only_one_specified.infer())

     both_specified = Rules().specify(1, 'Int').specify(2, 'Int').equal(1, 2)
     self.assertEqual(Result({1: 'Int'}, {2: 1}), both_specified.infer())
예제 #9
0
 def test_ignores_reverse_relation(self):
     """A type on the instance side must not flow back to the generic.

     Id 1 is ``'Int'`` and is declared an instance of id 2; after
     inference id 2 must still have no entry in ``result.types``.
     """
     result = Rules().specify(1, 'Int').instance_of(1, 2).infer()
     # assertIsNone states the intent directly (identity with None) and
     # gives a clearer failure message than comparing against None.
     self.assertIsNone(result.types.get(2))
예제 #10
0
 def test_applies_generic_relations(self):
     """An instance of a concretely typed id receives that type.

     Id 1 is ``'Int'`` and id 2 is declared an instance of it, so id 2 is
     expected to be ``'Int'`` in the inferred types.
     """
     rules = Rules().specify(1, 'Int').instance_of(2, 1)
     inferred_types = rules.infer().types
     self.assertEqual(inferred_types.get(2), 'Int')
예제 #11
0
 def test_rejects_invalid_generic_relations(self):
     """Expect InferenceError when an Int-equal id instantiates a Float.

     Id 2 is equal to id 1 (``'Int'``) and is also declared an instance
     of id 3 (``'Float'``); ``infer()`` must raise.
     """
     mismatched = (
         Rules()
         .specify(1, 'Int')
         .specify(3, 'Float')
         .equal(1, 2)
         .instance_of(2, 3)
     )
     with self.assertRaises(InferenceError):
         mismatched.infer()
예제 #12
0
 def test_complicated_subs_work(self):
     """Overlapping equalities over ids 1-5 collapse onto id 1.

     Only id 1 is specified (``'Int'``); every other id is expected to be
     substituted by 1 in the result.
     """
     rules = (
         Rules()
         .specify(1, 'Int')
         .equal(3, 4)
         .equal(1, 5)
         .equal(1, 2)
         .equal(5, 2)
         .equal(4, 5)
     )
     want = Result({1: 'Int'}, {2: 1, 3: 1, 4: 1, 5: 1})
     self.assertEqual(want, rules.infer())
예제 #13
0
 def test_catches_incompatible_sub(self):
     """Expect InferenceError when an Int id is equated with a Float id."""
     clashing = (
         Rules()
         .specify(1, 'Int')
         .specify(2, 'Float')
         .equal(1, 2)
     )
     with self.assertRaises(InferenceError):
         clashing.infer()
예제 #14
0
 def test_generics_with_no_types(self):
     """An instance relation between untyped ids yields an empty result."""
     untyped = Rules()
     # The return value of instance_of is intentionally not used here.
     untyped.instance_of(1, 2)
     nothing_inferred = Result({}, {})
     self.assertEqual(nothing_inferred, untyped.infer())
예제 #15
0
class InferenceTest(unittest.TestCase):
    """Inference tests driven through expression nodes.

    Each test builds an expression (``Variable``, ``Literal``, ``Let``, ...),
    adds it to a fresh ``Rules``/``Registry`` pair via ``add_to_rules``, and
    then checks what ``Rules.infer()`` returns or raises.
    """

    def setUp(self):
        """Give every test its own ``Rules`` and ``Registry``."""
        self._rules = Rules()
        self._registry = Registry()

    def test_nonpolymorphic_variable(self):
        """A lone variable scoped with flag ``False`` infers nothing."""
        # NOTE(review): the second tuple element is presumably the
        # "is polymorphic" flag, judging by the test names - confirm.
        self._registry.push_new_scope({'foo': ('var_foo_1', False)})
        Variable('foo').add_to_rules(self._rules, self._registry)
        self.assertEqual(Result({}, {}), self._rules.infer())

    def test_polymorphic_variable(self):
        """A lone variable scoped with flag ``True`` infers nothing."""
        self._registry.push_new_scope({'foo': ('var_foo_1', True)})
        Variable('foo').add_to_rules(self._rules, self._registry)
        self.assertEqual(Result({}, {}), self._rules.infer())

    def test_literal(self):
        """A literal's id is mapped to the literal's declared type."""
        literal = Literal('Int', 123)
        literal_id = literal.add_to_rules(self._rules, self._registry)
        self.assertEqual(Result({literal_id: 'Int'}, {}), self._rules.infer())

    def test_typed_expression_mismatch(self):
        """Annotating an Int literal as String raises InferenceError."""
        typed = TypedExpression('String', Literal('Int', 123))
        typed.add_to_rules(self._rules, self._registry)
        with self.assertRaises(InferenceError):
            self._rules.infer()

    def test_typed_expression_match(self):
        """A matching annotation types both the wrapper and the literal."""
        lit = Literal('Int', 123)
        te = TypedExpression('Int', lit)
        te_id = te.add_to_rules(self._rules, self._registry)
        lit_id = self._registry.get_id_for(lit)
        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(te_id))
        self.assertEqual('Int', result.get_type_by_id(lit_id))

    def test_application(self):
        """Applying an untyped function leaves the application untyped.

        The argument literal still gets its type; the application's own id
        has no type because nothing is known about ``times2``.
        """
        self._registry.push_new_scope({'times2': ('var_times2_1', True)})
        func = Variable('times2')
        arg = Literal('Int', 123)
        app = Application(func, [arg])
        app_id = app.add_to_rules(self._rules, self._registry)
        arg_id = self._registry.get_id_for(arg)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(arg_id))
        self.assertIsNone(result.get_type_by_id(app_id))

    def test_let(self):
        """``let x = 123 in x`` has type Int."""
        literal = Literal('Int', 123)
        lt = Let([('x', literal)], Variable('x'))
        lt_id = lt.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(lt_id))

    def test_multi_let(self):
        """A binding may refer to a later binding of the same ``Let``.

        ``x`` is bound to ``y`` and ``y`` to an Int literal; the body ``x``
        must come out as Int.
        """
        literal = Literal('Int', 123)
        lt = Let([('x', Variable('y')), ('y', literal)], Variable('x'))
        lt_id = lt.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(lt_id))

    def test_lambda_exprssion(self):
        """The identity lambda is a one-argument function ``a -> a``."""
        lm = Lambda(['x'], Variable('x'))
        lmid = lm.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual(
            ('Fn_1', 'var_x_2', 'var_x_2'),
            result.get_type_by_id(lmid)
        )

    def test_let_with_lambda(self):
        """A let-bound identity applied to a String yields String.

        ML code:
        let id = \\x -> x
        in id 'foo'
        """
        lm = Lambda(['x'], Variable('x'))
        var_id = Variable('id')
        app = Application(var_id, [Literal('String', 'foo')])
        lt = Let([('id', lm)], app)
        lt_id = lt.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('String', result.get_type_by_id(lt_id))

    def test_polymorphism(self):
        """A let-bound identity can be used at two different types.

        ML code:
        let id = \\x -> x
        in (id id) 123
        """
        lm = Lambda(['x'], Variable('x'))
        app1 = Application(Variable('id'), [Variable('id')])
        app2 = Application(app1, [Literal('Int', 123)])
        lt = Let([('id', lm)], app2)
        lt_id = lt.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(lt_id))

    def test_if_statement(self):
        """An ``If`` with a Bool test and two Int branches is Int."""
        test = Literal('Bool', True)
        if_case = Literal('Int', 123)
        else_case = Literal('Int', 456)
        if_block = If(test, if_case, else_case)

        if_id = if_block.add_to_rules(self._rules, self._registry)
        result = self._rules.infer()

        self.assertEqual('Int', result.get_type_by_id(if_id))

    def test_if_statement_requires_branches_to_equal(self):
        """Int and Float branches in one ``If`` raise InferenceError."""
        test = Literal('Bool', True)
        if_case = Literal('Int', 123)
        else_case = Literal('Float', 456)
        if_block = If(test, if_case, else_case)

        if_block.add_to_rules(self._rules, self._registry)
        with self.assertRaises(InferenceError):
            self._rules.infer()

    def test_if_statement_requires_test_to_be_boolean(self):
        """A String test expression in an ``If`` raises InferenceError."""
        test = Literal('String', 'not a boolean')
        if_case = Literal('Int', 123)
        else_case = Literal('Int', 456)
        if_block = If(test, if_case, else_case)

        if_block.add_to_rules(self._rules, self._registry)
        with self.assertRaises(InferenceError):
            self._rules.infer()

    def test_mutual_recursion(self):
        """Mutually recursive bindings resolve to a concrete function type.

        Equivalent ML:

        let-rec f = if True then 123 else g
                g = f
        in f

        Here ``f`` and ``g`` are built as zero-argument lambdas and each
        reference to the other is a zero-argument application, so the
        expected full type of the let is ``('Fn_0', 'Int')``.
        """
        test = Literal('Bool', True)
        if_case = Literal('Int', 123)
        else_case = Application(Variable('g'), [])
        if_block = If(test, if_case, else_case)
        f_func = Lambda([], if_block)

        g_body = Application(Variable('f'), [])
        g_func = Lambda([], g_body)

        let_body = Variable('f')
        let_expr = Let([('f', f_func), ('g', g_func)], let_body)

        let_id = let_expr.add_to_rules(self._rules, self._registry)
        result = self._rules.infer()
        self.assertEqual(('Fn_0', 'Int'), result.get_full_type_by_id(let_id))

    def test_generic_mutual_recursion(self):
        """Mutually recursive generic functions stay generic.

        Equivalent ML:

        let-rec f x = if True then x else g x
                g y = f y
        in f

        The expected full type of the let is ``('Fn_1', 'a0', 'a0')``.
        """
        test = Literal('Bool', True)
        if_case = Variable('x')
        else_case = Application(Variable('g'), [Variable('x')])
        if_block = If(test, if_case, else_case)
        f_func = Lambda(['x'], if_block)

        g_body = Application(Variable('f'), [Variable('y')])
        g_func = Lambda(['y'], g_body)

        # The let body is `f`, which is what the docstring's "in f" reflects.
        let_body = Variable('f')
        let_expr = Let([('f', f_func), ('g', g_func)], let_body)

        let_id = let_expr.add_to_rules(self._rules, self._registry)
        result = self._rules.infer()
        self.assertEqual(('Fn_1', 'a0', 'a0'), result.get_full_type_by_id(let_id))
예제 #16
0
 def setUp(self):
     """Create a fresh ``Rules`` and ``Registry`` before each test."""
     # Right-hand side is evaluated left to right: Rules() first, then
     # Registry(), matching the original construction order.
     self._rules, self._registry = Rules(), Registry()
예제 #17
0
 def test_accepts_circular_generic_relations(self):
     """Ids that are instances of each other do not break inference.

     Id 1 is ``'Int'``; 1 is an instance of 2 and 2 is an instance of 1.
     Inference must complete and give id 2 the type ``'Int'``.
     """
     circular = (
         Rules()
         .specify(1, 'Int')
         .instance_of(1, 2)
         .instance_of(2, 1)
     )
     inferred_types = circular.infer().types
     self.assertEqual(inferred_types.get(2), 'Int')