def test_apply_loop_invariant_optimisation_integer():
    """Check that ``v % (2*3*N)`` yields no scalar statements.

    The statement is run through `apply_loop_invariant_optimisations` with
    a non-scalar variable ``v`` and a constant ``N``; the returned list of
    scalar statements is expected to be empty.
    """
    namespace = {}
    namespace['v'] = Variable('v', Unit(1), scalar=False)
    namespace['N'] = Constant('N', Unit(1), 10)

    modulo_update = Statement('v', '=', 'v % (2*3*N)', '', np.float32)

    scalar_stmts, vector_stmts = apply_loop_invariant_optimisations(
        [modulo_update], namespace, np.float64)

    # Nothing from the modulo operand should have been pulled out into a
    # scalar statement.
    assert len(scalar_stmts) == 0
def test_apply_loop_invariant_optimisation_integer():
    """Verify that the optimisation pulls nothing out of ``v % (2*3*N)``.

    Only the scalar part of the result is inspected: it has to be empty.
    """
    # Build the arguments up front so the call under test reads on one line.
    variables = dict(
        v=Variable('v', Unit(1), scalar=False),
        N=Constant('N', Unit(1), 10),
    )
    statements = [
        Statement('v', '=', 'v % (2*3*N)', '', np.float32),
    ]
    requested_dtype = np.float64

    result = apply_loop_invariant_optimisations(statements, variables,
                                                requested_dtype)
    scalar, vector = result

    # The optimisation should not pull out any part of the modulo operand.
    assert len(scalar) == 0
def test_apply_loop_invariant_optimisation():
    """Check that ``exp(-dt/tau)`` is pulled out into one scalar statement.

    Two vector statements (for ``v`` and ``w``) both use ``exp(-dt/tau)``.
    After `apply_loop_invariant_optimisations` there must be exactly one
    scalar statement, named ``_lio_const_1`` and carrying the requested
    dtype, and both vector statements must refer to it.
    """
    namespace = {}
    namespace['v'] = Variable('v', Unit(1), scalar=False)
    namespace['w'] = Variable('w', Unit(1), scalar=False)
    namespace['dt'] = Constant('dt', second, 0.1*ms)
    namespace['tau'] = Constant('tau', second, 10*ms)
    namespace['exp'] = DEFAULT_FUNCTIONS['exp']

    update_v = Statement('v', '=',
                         'dt*w*exp(-dt/tau)/tau + v*exp(-dt/tau)',
                         '', np.float32)
    update_w = Statement('w', '=', 'w*exp(-dt/tau)', '', np.float32)

    scalar_stmts, vector_stmts = apply_loop_invariant_optimisations(
        [update_v, update_w], namespace, np.float64)

    # Exactly one constant is expected to be pulled out: exp(-dt / tau)
    assert len(scalar_stmts) == 1
    pulled_out = scalar_stmts[0]
    # np.float64 is the dtype passed to the optimisation above
    assert pulled_out.dtype == np.float64
    assert pulled_out.var == '_lio_const_1'

    # Both original statements remain and now use the new constant
    assert len(vector_stmts) == 2
    for stmt in vector_stmts:
        assert '_lio_const_1' in stmt.expr
def test_apply_loop_invariant_optimisation():
    """Verify the result of optimising two statements sharing ``exp(-dt/tau)``.

    Expected outcome: a single scalar statement ``_lio_const_1`` of dtype
    ``np.float64`` and two vector statements whose expressions mention it.
    """
    lio_name = '_lio_const_1'
    requested_dtype = np.float64

    variables = dict(
        v=Variable('v', Unit(1), scalar=False),
        w=Variable('w', Unit(1), scalar=False),
        dt=Constant('dt', second, 0.1 * ms),
        tau=Constant('tau', second, 10 * ms),
        exp=DEFAULT_FUNCTIONS['exp'],
    )
    statements = [
        Statement('v', '=', 'dt*w*exp(-dt/tau)/tau + v*exp(-dt/tau)', '',
                  np.float32),
        Statement('w', '=', 'w*exp(-dt/tau)', '', np.float32),
    ]

    result = apply_loop_invariant_optimisations(statements, variables,
                                                requested_dtype)
    scalar, vector = result

    # The optimisation should pull out exp(-dt / tau) and nothing else.
    assert len(scalar) == 1
    first_scalar = scalar[0]
    assert first_scalar.dtype == requested_dtype  # dtype requested above
    assert first_scalar.var == lio_name

    # Every vector statement has to reference the pulled-out constant.
    assert len(vector) == 2
    expressions = [stmt.expr for stmt in vector]
    assert all(lio_name in expr for expr in expressions)