def test_catches_generic_errors_with_separation(self):
    """Conflicting instances of a fully specified generic must be rejected.

    Id 1 is specified as ``('List', 11)`` with 11 as ``'Int'`` and id 3 as
    ``('List', 31)`` with 31 as ``'String'``; ids 2 and 3 are both declared
    instances of 1, and ``infer`` is expected to raise ``InferenceError``.
    """
    constraints = (
        Rules()
        .specify(1, ('List', 11))
        .specify(11, 'Int')
        .specify(3, ('List', 31))
        .specify(31, 'String')
        .instance_of(2, 1)
        .instance_of(3, 1)
    )
    with self.assertRaises(InferenceError):
        constraints.infer()
def test_applies_equality_constraint_from_generic(self):
    """A repeated variable in a generic ties the instance's slots together.

    Id 1 is ``('Fn_2', 'var_x', 'var_x', 't_none')``; id 10 is an instance
    of it with only slot 11 specified as ``'Int'``.  Both 11 and 12 are
    expected to come out as ``'Int'``.
    """
    constraints = (
        Rules()
        .specify(1, ('Fn_2', 'var_x', 'var_x', 't_none'))
        .specify('t_none', '()')
        .specify(10, ('Fn_2', 11, 12, 13))
        .specify(11, 'Int')
        .instance_of(10, 1)
    )
    inferred = constraints.infer()
    for type_id in (11, 12):
        self.assertEqual('Int', inferred.get_type_by_id(type_id))
def test_applies_recurisve_equality(self):
    """Equating two ``Pair`` types also equates their components.

    After ``equal(1, 2)`` the result is expected to keep id 1 as
    ``('Pair', 11, 22)`` and to record substitutions 2 -> 1, 21 -> 11 and
    12 -> 22.
    """
    constraints = (
        Rules()
        .specify(1, ('Pair', 11, 12))
        .specify(2, ('Pair', 21, 22))
        .specify(11, 'Int')
        .specify(22, 'String')
        .equal(1, 2)
    )
    inferred = constraints.infer()

    wanted_types = {1: ('Pair', 11, 22), 11: 'Int', 22: 'String'}
    wanted_subs = {2: 1, 21: 11, 12: 22}
    self.assertEqual(wanted_types, inferred.types)
    self.assertEqual(wanted_subs, inferred.subs)
def test_applies_generics_for_multiple_levels(self):
    """Untyped instances of a generic pick up the generic's full type.

    Ids 2 and 3 are declared instances of id 1 (``('List', 11)`` with 11 as
    ``'Int'``); both are expected to resolve to ``('List', 11)`` with no
    substitutions recorded.
    """
    constraints = (
        Rules()
        .specify(1, ('List', 11))
        .specify(11, 'Int')
        .instance_of(2, 1)
        .instance_of(3, 1)
    )
    list_of_11 = ('List', 11)
    wanted = Result({1: list_of_11, 2: list_of_11, 3: list_of_11, 11: 'Int'}, {})
    self.assertEqual(wanted, constraints.infer())
def test_applies_generics_recursively(self):
    """An instance relation propagates types into nested components.

    Id 1 (``('Pair', 11, 12)``) is declared an instance of id 2
    (``('Pair', 21, 22)``).  With 11 as ``'Int'`` and 22 as ``'String'``,
    component 12 is expected to become ``'String'`` while no substitutions
    are recorded.
    """
    constraints = (
        Rules()
        .specify(1, ('Pair', 11, 12))
        .specify(2, ('Pair', 21, 22))
        .specify(11, 'Int')
        .specify(22, 'String')
        .instance_of(1, 2)
    )
    wanted_types = {
        1: ('Pair', 11, 12),
        2: ('Pair', 21, 22),
        11: 'Int',
        12: 'String',
        22: 'String',
    }
    wanted = Result(wanted_types, {})
    self.assertEqual(wanted, constraints.infer())
def test_allows_multiple_generic_instantiations(self):
    """One generic may be instantiated at different element types.

    Id 1 is ``('List', 11)`` with 11 left unspecified; ids 2 and 3 are
    instances of it whose elements are ``'Int'`` and ``'String'``.  The
    expected result keeps all three list types and leaves 11 untyped.
    """
    constraints = (
        Rules()
        .specify(1, ('List', 11))
        .specify(2, ('List', 21))
        .specify(3, ('List', 31))
        .specify(21, 'Int')
        .specify(31, 'String')
        .instance_of(2, 1)
        .instance_of(3, 1)
    )
    wanted_types = {
        1: ('List', 11),
        2: ('List', 21),
        3: ('List', 31),
        21: 'Int',
        31: 'String',
    }
    wanted = Result(wanted_types, {})
    self.assertEqual(wanted, constraints.infer())
def test_passthrough(self):
    """Rules without relations come back from ``infer`` unchanged."""
    # NOTE(review): the expectations here are plain tuples, whereas the
    # neighbouring tests compare against ``Result(...)``; this relies on the
    # inferred result comparing equal to a 2-tuple -- confirm against Result.
    nothing_specified = ({}, {})
    self.assertEqual(nothing_specified, Rules().infer())

    one_type_specified = ({1: 'Int'}, {})
    self.assertEqual(one_type_specified, Rules().specify(1, 'Int').infer())
def test_applies_substitution(self):
    """``equal(1, 2)`` collapses id 2 onto id 1.

    Checked both when only id 1 is typed and when both ids carry the same
    ``'Int'`` type; either way the expected result is ``{1: 'Int'}`` with the
    substitution ``{2: 1}``.
    """
    only_first_typed = Rules().specify(1, 'Int').equal(1, 2)
    self.assertEqual(Result({1: 'Int'}, {2: 1}), only_first_typed.infer())

    both_typed = Rules().specify(1, 'Int').specify(2, 'Int').equal(1, 2)
    self.assertEqual(Result({1: 'Int'}, {2: 1}), both_typed.infer())
def test_ignores_reverse_relation(self):
    """An instance relation must not push a type back onto the generic.

    Id 1 is ``'Int'`` and is declared an instance of id 2; id 2 is expected
    to have no entry in the inferred types.
    """
    result = Rules().specify(1, 'Int').instance_of(1, 2).infer()
    # assertIsNone checks identity with None, which is the intent here and
    # gives a clearer failure message than an equality comparison.
    self.assertIsNone(result.types.get(2))
def test_applies_generic_relations(self):
    """An untyped instance takes on the type of its generic.

    Id 2 is declared an instance of id 1 (``'Int'``) and is expected to be
    inferred as ``'Int'``.
    """
    constraints = Rules().specify(1, 'Int').instance_of(2, 1)
    inferred_types = constraints.infer().types
    self.assertEqual(inferred_types.get(2), 'Int')
def test_rejects_invalid_generic_relations(self):
    """An instance relation that contradicts an equality must fail.

    Id 2 is equated with id 1 (``'Int'``) and also declared an instance of
    id 3 (``'Float'``); ``infer`` is expected to raise ``InferenceError``.
    """
    constraints = (
        Rules()
        .specify(1, 'Int')
        .specify(3, 'Float')
        .equal(1, 2)
        .instance_of(2, 3)
    )
    with self.assertRaises(InferenceError):
        constraints.infer()
def test_complicated_subs_work(self):
    """A tangle of equalities resolves to a single representative id.

    Ids 1 through 5 are linked by several overlapping ``equal`` calls; the
    expected result types only id 1 and substitutes 2, 3, 4 and 5 with 1.
    """
    constraints = (
        Rules()
        .specify(1, 'Int')
        .equal(3, 4)
        .equal(1, 5)
        .equal(1, 2)
        .equal(5, 2)
        .equal(4, 5)
    )
    wanted = Result({1: 'Int'}, {2: 1, 3: 1, 4: 1, 5: 1})
    self.assertEqual(wanted, constraints.infer())
def test_catches_incompatible_sub(self):
    """Equating an ``'Int'`` id with a ``'Float'`` id must raise."""
    constraints = (
        Rules()
        .specify(1, 'Int')
        .specify(2, 'Float')
        .equal(1, 2)
    )
    with self.assertRaises(InferenceError):
        constraints.infer()
def test_generics_with_no_types(self):
    """An instance relation between two untyped ids infers nothing."""
    untyped = Rules()
    # Called as a statement (not chained) exactly as in the original test.
    untyped.instance_of(1, 2)

    nothing_inferred = Result({}, {})
    self.assertEqual(nothing_inferred, untyped.infer())
class InferenceTest(unittest.TestCase):
    """Inference tests driven through expression nodes.

    Each test builds expression objects, registers them with
    ``add_to_rules`` against a fresh ``Rules``/``Registry`` pair, runs
    ``infer`` and checks the resulting types.
    """

    def setUp(self):
        """Give every test its own rule set and registry."""
        self._rules = Rules()
        self._registry = Registry()

    def test_nonpolymorphic_variable(self):
        """A lone non-polymorphic variable yields an empty result."""
        self._registry.push_new_scope({'foo': ('var_foo_1', False)})
        Variable('foo').add_to_rules(self._rules, self._registry)
        self.assertEqual(Result({}, {}), self._rules.infer())

    def test_polymorphic_variable(self):
        """A lone polymorphic variable yields an empty result."""
        self._registry.push_new_scope({'foo': ('var_foo_1', True)})
        Variable('foo').add_to_rules(self._rules, self._registry)
        self.assertEqual(Result({}, {}), self._rules.infer())

    def test_literal(self):
        """A literal is inferred as its declared type."""
        literal = Literal('Int', 123)
        literal_id = literal.add_to_rules(self._rules, self._registry)
        self.assertEqual(Result({literal_id: 'Int'}, {}), self._rules.infer())

    def test_typed_expression_mismatch(self):
        """Annotating an ``Int`` literal as ``String`` must raise."""
        typed = TypedExpression('String', Literal('Int', 123))
        typed.add_to_rules(self._rules, self._registry)
        with self.assertRaises(InferenceError):
            self._rules.infer()

    def test_typed_expression_match(self):
        """A matching annotation types both the wrapper and the literal."""
        literal = Literal('Int', 123)
        typed = TypedExpression('Int', literal)
        typed_id = typed.add_to_rules(self._rules, self._registry)
        literal_id = self._registry.get_id_for(literal)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(typed_id))
        self.assertEqual('Int', result.get_type_by_id(literal_id))

    def test_application(self):
        """Applying an untyped function types the argument, not the call."""
        self._registry.push_new_scope({'times2': ('var_times2_1', True)})
        func = Variable('times2')
        arg = Literal('Int', 123)
        application = Application(func, [arg])
        application_id = application.add_to_rules(self._rules, self._registry)
        # The id itself is not needed; the lookup is kept so the test still
        # performs the same registry call for the function variable as before.
        self._registry.get_id_for(func)
        arg_id = self._registry.get_id_for(arg)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(arg_id))
        self.assertIsNone(result.get_type_by_id(application_id))

    def test_let(self):
        """``let x = 123 in x`` is inferred as ``Int``."""
        literal = Literal('Int', 123)
        let_expr = Let([('x', literal)], Variable('x'))
        let_id = let_expr.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(let_id))

    def test_multi_let(self):
        """A binding may refer to a later binding in the same ``Let``."""
        literal = Literal('Int', 123)
        let_expr = Let(
            [('x', Variable('y')), ('y', literal)],
            Variable('x'),
        )
        let_id = let_expr.add_to_rules(self._rules, self._registry)
        # The id itself is not needed; the lookup is kept so the test still
        # performs the same registry call for the literal as before.
        self._registry.get_id_for(literal)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(let_id))

    def test_lambda_exprssion(self):
        """The identity lambda gets the same type for argument and result."""
        identity = Lambda(['x'], Variable('x'))
        identity_id = identity.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual(
            ('Fn_1', 'var_x_2', 'var_x_2'),
            result.get_type_by_id(identity_id),
        )

    def test_let_with_lambda(self):
        """
        ML code:

        let id = \\x -> x in
          id 'foo'
        """
        identity = Lambda(['x'], Variable('x'))
        call = Application(Variable('id'), [Literal('String', 'foo')])
        let_expr = Let([('id', identity)], call)
        let_id = let_expr.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('String', result.get_type_by_id(let_id))

    def test_polymorphism(self):
        """
        ML code:

        let id = \\x -> x in
          (id id) 123
        """
        identity = Lambda(['x'], Variable('x'))
        id_applied_to_id = Application(Variable('id'), [Variable('id')])
        call = Application(id_applied_to_id, [Literal('Int', 123)])
        let_expr = Let([('id', identity)], call)
        let_id = let_expr.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(let_id))

    def test_if_statement(self):
        """An ``If`` with two ``Int`` branches is inferred as ``Int``."""
        if_block = If(
            Literal('Bool', True),
            Literal('Int', 123),
            Literal('Int', 456),
        )
        if_id = if_block.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual('Int', result.get_type_by_id(if_id))

    def test_if_statement_requires_branches_to_equal(self):
        """``Int`` and ``Float`` branches in one ``If`` must raise."""
        if_block = If(
            Literal('Bool', True),
            Literal('Int', 123),
            Literal('Float', 456),
        )
        if_block.add_to_rules(self._rules, self._registry)
        with self.assertRaises(InferenceError):
            self._rules.infer()

    def test_if_statement_requires_test_to_be_boolean(self):
        """A ``String`` condition in an ``If`` must raise."""
        if_block = If(
            Literal('String', 'not a boolean'),
            Literal('Int', 123),
            Literal('Int', 456),
        )
        if_block.add_to_rules(self._rules, self._registry)
        with self.assertRaises(InferenceError):
            self._rules.infer()

    def test_mutual_recursion(self):
        """
        Equivalent ML (f and g are zero-argument functions):

        let-rec f = \\() -> if True then 123 else g ()
                g = \\() -> f ()
        in f
        """
        if_block = If(
            Literal('Bool', True),
            Literal('Int', 123),
            Application(Variable('g'), []),
        )
        f_func = Lambda([], if_block)
        g_func = Lambda([], Application(Variable('f'), []))
        let_expr = Let([('f', f_func), ('g', g_func)], Variable('f'))
        let_id = let_expr.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual(('Fn_0', 'Int'), result.get_full_type_by_id(let_id))

    def test_generic_mutual_recursion(self):
        """
        Equivalent ML:

        let-rec f x = if True then x else g x
                g y = f y
        in f
        """
        if_block = If(
            Literal('Bool', True),
            Variable('x'),
            Application(Variable('g'), [Variable('x')]),
        )
        f_func = Lambda(['x'], if_block)
        g_func = Lambda(['y'], Application(Variable('f'), [Variable('y')]))
        # The let body is ``f``, so the asserted type is that of ``f``.
        let_expr = Let([('f', f_func), ('g', g_func)], Variable('f'))
        let_id = let_expr.add_to_rules(self._rules, self._registry)

        result = self._rules.infer()
        self.assertEqual(
            ('Fn_1', 'a0', 'a0'),
            result.get_full_type_by_id(let_id),
        )
def setUp(self):
    """Create a fresh ``Rules`` and ``Registry`` before each test."""
    self._rules, self._registry = Rules(), Registry()
def test_accepts_circular_generic_relations(self):
    """Two ids declared instances of each other still resolve.

    Id 1 is ``'Int'``; 1 is an instance of 2 and 2 is an instance of 1.
    Id 2 is expected to be inferred as ``'Int'``.
    """
    constraints = (
        Rules()
        .specify(1, 'Int')
        .instance_of(1, 2)
        .instance_of(2, 1)
    )
    inferred_types = constraints.infer().types
    self.assertEqual(inferred_types.get(2), 'Int')