def test_module_attrs(self):
    """Top-level Dex bindings are readable as module attributes.

    Each attribute must stringify to the same text Dex uses to render it.
    """
    module = dex.Module(dedent("""
        x = 2.5
        y = [2, 3, 4]
        """))
    # Expected rendering per binding; checked in insertion order (x, then y).
    expected_renderings = {"x": "2.5", "y": "[2, 3, 4]"}
    for binding, rendering in expected_renderings.items():
        assert str(getattr(module, binding)) == rendering
def test_polymorphic_array_1d(self):
    """A size-polymorphic Dex function accepts float32 vectors of several lengths.

    `addTwo` is generic over the table size `n`. It is checked against the
    NumPy reference `arr + 2` for vectors of length 2, 5 and 10.
    """
    module = dex.Module(dedent("""
        def addTwo {n} (x: (Fin n)=>Float) : (Fin n)=>Float = for i. x.i + 2.0
        """))
    vector_lengths = (2, 5, 10)
    # Each entry is a 1-tuple because addTwo takes a single array argument.
    # NOTE(review): presumably check_atom (defined elsewhere) calls the atom
    # with each tuple and compares against the reference lambda -- confirm.
    arg_sets = [
        (np.arange(length, dtype=np.float32),) for length in vector_lengths
    ]
    check_atom(module.addTwo, lambda arr: arr + 2, arg_sets)
def test_function_call():
    """Functions reachable through a Dex module can be applied to evaluated atoms.

    Covers a function defined in the module source (`addOne`) and `sum`,
    which the source itself does not define.
    """
    module = dex.Module(dedent("""
        def addOne (x: Float) : Float = x + 1.0
        """))
    scalar_atom = dex.eval("2.5")
    table_atom = dex.eval("[2, 3, 4]")
    assert str(module.addOne(scalar_atom)) == "3.5"
    # NOTE(review): `sum` is not declared in the module source above; it
    # presumably resolves through the Dex prelude the module is built on.
    assert str(module.sum(table_atom)) == "9"
def test_polymorphic_array_2d(self):
    """A transpose polymorphic in both dimensions matches NumPy's `.T`.

    Every (rows, cols) pair drawn from {2, 5, 10} is exercised, which
    includes both square and non-square float32 matrices.
    """
    module = dex.Module(dedent("""
        def myTranspose {n m} (x : (Fin n)=>(Fin m)=>Float) : (Fin m)=>(Fin n)=>Float = for i j. x.j.i
        """))
    dimension_sizes = (2, 5, 10)
    arg_sets = []
    for rows, cols in it.product(dimension_sizes, repeat=2):
        matrix = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols))
        # Single-argument call, so each argument set is a 1-tuple.
        arg_sets.append((matrix,))
    check_atom(module.myTranspose, lambda mat: mat.T, arg_sets)
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import dex

# Build a Dex module with two top-level bindings: `x` (a scalar) and
# `y` (a three-element literal).
m = dex.Module("""
x = 2.5
y = [2, 3, 4]
""")

# One print call, one value per line: the same output as printing each in turn.
print(m.x, m.y, sep="\n")
# Convert the scalar binding with int(); the exact rounding depends on the
# Atom's int conversion -- NOTE(review): confirm against dex.Atom.
print(int(m.x))