def main(functions=(), j=1):
    """
    Generates expression tests.

    For every signature prim, rev and fwd instantiations are tested (with all
    scalars of type double/var/fvar<double>). Tests check the following:
     - signatures can be compiled with expressions arguments
     - results when given expressions are same as when given plain matrices
       (including derivatives)
     - functions evaluate expressions at most once

    Signatures are processed in sorted order, so the generated test names
    (function name + index) and the split of tests across output files are
    reproducible from one run to the next.

    :param functions: functions to generate tests for. This can contain names
        of functions already supported by stanc3, full function signatures or
        file names of files containing any of the previous two.
        Default: all signatures supported by stanc3
    :param j: number of files to split tests in
    """
    tests = []
    functions, signatures = handle_function_list(functions)

    signatures = set(signatures)

    if not functions and not signatures:
        default_checks = True
        signatures |= set(get_signatures())
    else:
        for signature in get_signatures():
            sp = SignatureParser(signature)
            if sp.function_name in functions:
                signatures.add(signature)
        default_checks = False

    # Iterating a set of strings follows hash order, which is randomized per
    # interpreter run; sort so the numbering ``n`` used in test names is stable.
    for n, signature in enumerate(sorted(signatures)):
        # The parser and these filters depend only on the signature, not on the
        # overload, so evaluate them once per signature.
        sp = SignatureParser(signature)

        # skip ignored signatures if running the default checks
        if default_checks and sp.is_ignored():
            continue

        # skip signatures without inputs that can be eigen types
        if not sp.has_eigen_compatible_arg():
            continue

        for overload in ("Prim", "Rev", "Fwd"):
            # skip functions in forward mode with no forward mode overload,
            # same for reverse mode
            if overload == "Fwd" and not sp.is_fwd_compatible():
                continue
            if overload == "Rev" and not sp.is_rev_compatible():
                continue

            is_reverse_mode = overload == "Rev" and not sp.returns_int()

            cg = CodeGenerator()

            # Generate two sets of arguments, one will be expressions one will not
            arg_list_expression_base = cg.build_arguments(
                sp, sp.number_arguments() * [overload], size=1)
            arg_list_expression = [
                cg.convert_to_expression(arg, size=1)
                if arg.is_eigen_compatible() else arg
                for arg in arg_list_expression_base
            ]
            arg_list_no_expression = cg.build_arguments(
                sp, sp.number_arguments() * [overload], size=1)

            # Check results with expressions and without are the same
            cpp_function_name = "stan::math::" + sp.function_name
            result = cg.function_call_assign(cpp_function_name,
                                             *arg_list_expression)
            result_no_expression = cg.function_call_assign(
                cpp_function_name, *arg_list_no_expression)
            cg.expect_eq(result, result_no_expression)

            # Check that expressions evaluated only once
            for arg in arg_list_expression:
                if arg.is_expression():
                    cg.expect_leq_one(arg.counter)

            # If reverse mode, check the adjoints
            if is_reverse_mode:
                summed_result = cg.recursive_sum(result)
                summed_result_no_expression = cg.recursive_sum(
                    result_no_expression)
                sum_of_sums = cg.add(summed_result,
                                     summed_result_no_expression)
                cg.grad(sum_of_sums)

                # Compare each plain argument with the base variable the
                # expression argument was built from (convert_to_expression
                # wraps the base in a block/unaryExpr view), so the check is on
                # the independent variables rather than on the wrapper.
                for arg_no_expression, arg_base in zip(
                        arg_list_no_expression, arg_list_expression_base):
                    cg.expect_adj_eq(arg_no_expression, arg_base)

                cg.recover_memory()

            tests.append(
                test_code_template.format(
                    overload=overload,
                    comment=signature.strip(),
                    test_name=sp.function_name + repr(n),
                    code=cg.cpp(),
                ))

    save_tests_in_files(j, tests)
class CodeGeneratorTest(unittest.TestCase):
    """Checks the exact C++ text that ``CodeGenerator`` emits for each helper."""

    def setUp(self):
        # Every test starts from the same fixtures: one signature to build
        # arguments for, a handful of named variables and an empty generator.
        self.add = SignatureParser("add(real, vector) => vector")
        self.int_var = IntVariable("myint", 5)
        self.real_var1 = RealVariable("Rev", "myreal1", 0.5)
        self.real_var2 = RealVariable("Rev", "myreal2", 0.5)
        self.matrix_var = MatrixVariable("Rev", "mymatrix", "matrix", 2, 0.5)
        self.cg = CodeGenerator()

    def _assert_cpp(self, expected):
        """Compare everything the generator has emitted so far to ``expected``."""
        self.assertEqual(self.cg.cpp(), expected)

    def _assert_built_args(self, overloads, size, expected):
        """Build arguments for the ``add`` signature, then check the code."""
        self.cg.build_arguments(self.add, overloads, size)
        self._assert_cpp(expected)

    def test_prim_prim(self):
        self._assert_built_args(
            ["Prim", "Prim"], 1,
            "double real0 = 0.4; "
            "auto matrix1 = stan::test::make_arg<Eigen::Matrix<double, Eigen::Dynamic, 1>>(0.4, 1);",
        )

    def test_prim_rev(self):
        self._assert_built_args(
            ["Prim", "Rev"], 1,
            "double real0 = 0.4; "
            "auto matrix1 = stan::test::make_arg<Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>>(0.4, 1);",
        )

    def test_rev_rev(self):
        self._assert_built_args(
            ["Rev", "Rev"], 1,
            "stan::math::var real0 = 0.4; "
            "auto matrix1 = stan::test::make_arg<Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>>(0.4, 1);",
        )

    def test_size(self):
        self._assert_built_args(
            ["Rev", "Rev"], 2,
            "stan::math::var real0 = 0.4; "
            "auto matrix1 = stan::test::make_arg<Eigen::Matrix<stan::math::var, Eigen::Dynamic, 1>>(0.4, 2);",
        )

    def test_add(self):
        self.cg.add(self.real_var1, self.real_var2)
        self._assert_cpp(
            "auto sum_of_sums0 = stan::math::eval(stan::math::add(myreal1,myreal2));")

    def test_convert_to_expression(self):
        self.cg.convert_to_expression(self.matrix_var)
        self._assert_cpp(
            "int mymatrix_expr0_counter = 0; "
            "stan::test::counterOp<stan::math::var> mymatrix_expr0_counter_op(&mymatrix_expr0_counter); "
            "auto mymatrix_expr0 = mymatrix.block(0,0,2,2).unaryExpr(mymatrix_expr0_counter_op);"
        )

    def test_expect_adj_eq(self):
        self.cg.expect_adj_eq(self.real_var1, self.real_var2)
        self._assert_cpp("stan::test::expect_adj_eq(myreal1,myreal2);")

    def test_expect_eq(self):
        self.cg.expect_eq(self.real_var1, self.real_var2)
        self._assert_cpp("EXPECT_STAN_EQ(myreal1,myreal2);")

    def test_expect_leq_one(self):
        self.cg.expect_leq_one(self.int_var)
        self._assert_cpp("int int0 = 1; " "EXPECT_LE(myint,int0);")

    def test_function_call_assign(self):
        self.cg.function_call_assign("stan::math::add", self.real_var1,
                                     self.real_var2)
        self._assert_cpp(
            "auto result0 = stan::math::eval(stan::math::add(myreal1,myreal2));")

    def test_grad(self):
        self.cg.grad(self.real_var1)
        self._assert_cpp("stan::test::grad(myreal1);")

    def test_recover_memory(self):
        self.cg.recover_memory()
        self._assert_cpp("stan::math::recover_memory();")

    def test_recursive_sum(self):
        self.cg.recursive_sum(self.real_var1)
        self._assert_cpp(
            "auto summed_result0 = stan::math::eval(stan::test::recursive_sum(myreal1));"
        )

    def test_to_var_value(self):
        self.cg.to_var_value(self.matrix_var)
        self._assert_cpp(
            "auto mymatrix_varmat0 = stan::math::eval(stan::math::to_var_value(mymatrix));"
        )