def test_higher_function(self):
    # A Function whose first argument is itself a Function: it must apply
    # that argument to every element, and report a nested function type.
    mapper = Function(
        'map',
        lambda fn, items: [fn(item) for item in items],
        (FunctionType(INT, INT), LIST),
        LIST,
    )
    triple = Function('3x', lambda n: 3 * n, INT, INT)

    # Calling the higher-order function with a Function and a ListValue.
    mapped = mapper(triple, ListValue([1, 2, 3]))
    self.assertEqual(mapped, ListValue([3, 6, 9]))

    # The expected type is built from fresh instances, so the comparison
    # goes through == on FunctionType rather than object identity.
    self.assertEqual(
        mapper.type,
        FunctionType((FunctionType(INT, INT), LIST), LIST),
    )
    self.assertEqual(str(mapper.type), 'F((F(INT, INT), LIST), LIST)')
def test_function(self):
    # Basic Function contract: declared type, the type's string form,
    # calling it on a wrapped int, and the name exposed via str() and .name.
    f = Function('foo', lambda x: x + 1, INT, INT)

    # A separately constructed FunctionType, so assertEqual relies on ==
    # rather than on both sides being the same object.
    expected_type = FunctionType(INT, INT)
    self.assertEqual(f.type, expected_type)
    self.assertEqual(str(f.type), 'F(INT, INT)')

    # Calling with IntValue(1) must compare equal to IntValue(2).
    self.assertEqual(f(IntValue(1)), IntValue(2))

    # The name given at construction is visible both ways.
    self.assertEqual(str(f), 'foo')
    self.assertEqual(f.name, 'foo')
# Fragment of the module-level function table, collapsed onto one line. The
# trailing LAMBDAS list is cut off in this view, so the code is left untouched.
#   LMAX     - 'max' over (INT, INT) -> INT, backed by the builtin max.
#   _scan1l  - running fold over xs: ys[0] is xs[0] and ys[i] is
#              f(ys[i - 1], xs[i]); an empty xs gives [].
#   MAP      - applies a unary function to every element of a list.
#   FILTER   - keeps the elements for which the unary function is truthy.
#   COUNT    - number of elements for which the unary function is truthy.
#   SCAN1L   - _scan1l registered with a binary (INT, INT) -> INT function
#              and a LIST, returning a LIST.
#   ZIPWITH  - applies a binary function pairwise over zip(xs, ys), so the
#              result is as long as the shorter input.
LMAX = Function('max', max, (INT, INT), INT) # higher order functions def _scan1l(f, xs): ys = [0] * len(xs) for i, x in enumerate(xs): if i: ys[i] = f(ys[i - 1], x) else: ys[i] = x return ys MAP = Function('MAP', lambda f, xs: [f(x) for x in xs], (FunctionType(INT, INT), LIST), LIST) FILTER = Function('FILTER', lambda f, xs: [x for x in xs if f(x)], (FunctionType(INT, BOOL), LIST), LIST) COUNT = Function('COUNT', lambda f, xs: len([x for x in xs if f(x)]), (FunctionType(INT, BOOL), LIST), INT) SCAN1L = Function('SCAN1L', _scan1l, (FunctionType((INT, INT), INT), LIST), LIST) ZIPWITH = Function('ZIPWITH', lambda f, xs, ys: [f(x, y) for x, y in zip(xs, ys)], (FunctionType((INT, INT), INT), LIST, LIST), LIST) LAMBDAS = [ PLUS1, MINUS1, TIMES2, DIV2,
# Fragment of the module-level function table, collapsed onto one line. The
# trailing LAMBDAS list is cut off in this view, so the code is left untouched.
#   LTIMES   - '*' over (INT, INT) -> INT (x * y).
#   LDIV     - present only as commented-out code; it is not defined here.
#   LMIN     - 'min' over (INT, INT) -> INT, backed by the builtin min.
#   LMAX     - 'max' over (INT, INT) -> INT, backed by the builtin max.
#   _scan1l  - running fold over xs: ys[0] is xs[0] and ys[i] is
#              f(ys[i - 1], xs[i]); an empty xs gives [].
#   MAP      - applies a unary function to every element of a list.
#   FILTER   - keeps the elements for which the unary function is truthy.
#   COUNT    - number of elements for which the unary function is truthy.
#   SCAN1L   - _scan1l registered with a binary (INT, INT) -> INT function
#              and a LIST, returning a LIST.
#   ZIPWITH  - applies a binary function pairwise over zip(xs, ys), so the
#              result is as long as the shorter input.
LTIMES = Function('*', lambda x, y: x * y, (INT, INT), INT) #LDIV = Function('/', lambda x, y: x / y if y else None, (INT, INT), INT) LMIN = Function('min', min, (INT, INT), INT) LMAX = Function('max', max, (INT, INT), INT) # higher order functions def _scan1l(f, xs): ys = [0] * len(xs) for i, x in enumerate(xs): if i: ys[i] = f(ys[i - 1], x) else: ys[i] = x return ys MAP = Function('MAP', lambda f, xs: [f(x) for x in xs], (FunctionType(INT, INT), LIST), LIST) FILTER = Function('FILTER', lambda f, xs: [x for x in xs if f(x)], (FunctionType(INT, BOOL), LIST), LIST) COUNT = Function('COUNT', lambda f, xs: len([x for x in xs if f(x)]), (FunctionType(INT, BOOL), LIST), INT) SCAN1L = Function('SCAN1L', _scan1l, (FunctionType((INT, INT), INT), LIST), LIST) ZIPWITH = Function('ZIPWITH', lambda f, xs, ys: [f(x, y) for x, y in zip(xs, ys)], (FunctionType((INT, INT), INT), LIST, LIST), LIST) LAMBDAS = [ PLUS1, MINUS1, TIMES2, DIV2, TIMESNEG1, POW2, TIMES3, DIV3, TIMES4,
def __init__(self, name, f, input_type, output_type):
    """Set up a named function from a callable and its declared types.

    Args:
        name: Stored unchanged on the instance as ``self.name``.
        f: Forwarded as the first argument to the base-class initializer.
        input_type: First argument to ``FunctionType``.
        output_type: Second argument to ``FunctionType``.
    """
    # Build the type descriptor, then let the base class store it with f.
    signature = FunctionType(input_type, output_type)
    super(Function, self).__init__(f, signature)
    self.name = name
def test_function_type(self):
    # A FunctionType given a single (non-tuple) input type must render as
    # 'F(INT, INT)' and expose its inputs as a one-element tuple.
    int_to_int = FunctionType(INT, INT)

    rendered = str(int_to_int)
    self.assertEqual(rendered, 'F(INT, INT)')

    self.assertEqual(int_to_int.input_types, (INT,))