def test_case_for_transposes(batch, beta, trans_lhs, trans_rhs):
    """
    Build the full generated test case as a list of text fragments.

    The list opens with the licence text, the "do not modify" notice naming
    this script, the include block and the data type list. Next comes the
    typed test suite declaration. Last are the generated test lines for every
    (m, k, n) combination of the input sizes.
    """
    script_name = os.path.basename(__file__)
    suite_name = TEST_CASE_TPL.format(
        batch=batch,
        beta=beta,
        trans_lhs=trans_lhs,
        trans_rhs=trans_rhs,
    )
    fragments = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=script_name),
        INCLUDES,
        DATA_TYPES,
    ]
    # The declaration receives lower-cased transpose flags, while the suite
    # name above is formatted with the raw values.
    fragments.append(
        TYPED_TEST_SUITE_DECL_TPL.format(
            test_case=suite_name,
            trans_lhs=helpers.to_lower_case_str(trans_lhs),
            trans_rhs=helpers.to_lower_case_str(trans_rhs),
        ))
    sizes = get_input_sizes()
    # Each dimension gets the sizes passed separately (not repeat=3), exactly
    # as the original generator does.
    for dims in itertools.product(sizes, sizes, sizes):
        m, k, n = dims
        fragments += get_test_lines(batch, m, k, n, beta, trans_lhs, trans_rhs)
    return fragments
def output_for_test_case(test_case):
    """
    Build the full generated test case as a list of text fragments.

    The list holds the licence text, the "do not modify" notice naming this
    script, the include block and the typed test case declaration. Then, for
    each parameter set yielded by ``test_params_for_test_case``, it holds that
    set's test lines followed by a single newline fragment.
    """
    script_name = os.path.basename(__file__)
    type_name = helpers.to_camel_case(test_case.test_type)
    suite_name = TEST_CASE_TPL.format(
        test_type=type_name,
        direction=helpers.to_camel_case(test_case.direction),
        operation=helpers.to_camel_case(test_case.operation),
    )
    # The declaration maps direction/operation through lookup tables instead
    # of camel-casing them as the suite name does.
    declaration = (
        TYPED_TEST_CASE_DECL_TPL,
        dict(
            test_case=suite_name,
            direction=DIRECTION_MAP[test_case.direction],
            operation=OPERATION_MAP[test_case.operation],
        ),
    )
    result = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=script_name),
        INCLUDES,
    ]
    template, fields = declaration
    result.append(template.format(**fields))
    for params in test_params_for_test_case(test_case):
        result += get_test_lines(test_case, params)
        result.append("\n")
    return result
def get_initial_boilerplate():
    """
    Return the fragments that open a generated test file.

    In order: the licence text, the "do not modify" notice naming this
    script, the include block and the data type list.
    """
    script_name = os.path.basename(__file__)
    licence = helpers.get_license()
    notice = helpers.get_dont_modify_comment(scriptname=script_name)
    return [licence, notice, INCLUDES, DATA_TYPES]
def transpose_test_case(n_dimensions):
    """
    Build the full generated transpose test case as a list of text fragments.

    The list opens with the licence text, the "do not modify" notice naming
    this script, the include block, the data type list and the typed test
    suite declaration. After that come the test lines for every
    (input shape, permutation) pair yielded by ``test_cases(n_dimensions)``.
    """
    script_name = os.path.basename(__file__)
    suite_name = TEST_CASE_TPL.format(n_dimensions=n_dimensions)
    fragments = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=script_name),
        INCLUDES,
        DATA_TYPES,
        TYPED_TEST_SUITE_DECL_TPL.format(test_case=suite_name),
    ]
    for shape, perm in test_cases(n_dimensions):
        fragments += get_test_lines(shape, perm)
    return fragments
def output_for_test_case(test_case):
    """
    Build the full generated test case as a list of text fragments.

    The list opens with the licence text, the "do not modify" notice naming
    this script, the include block, the data type list and the typed test
    suite declaration parameterised by the case's window and stride. After
    that come the test lines for each parameter set yielded by
    ``test_params_for_test_case``.
    """
    script_name = os.path.basename(__file__)
    type_name = helpers.to_camel_case(test_case.test_type)
    suite_name = TEST_CASE_TPL.format(
        test_type=type_name,
        window=test_case.window,
        stride=test_case.stride,
    )
    declaration = TYPED_TEST_SUITE_DECL_TPL.format(
        test_case=suite_name,
        window=test_case.window,
        stride=test_case.stride,
    )
    fragments = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=script_name),
        INCLUDES,
        DATA_TYPES,
    ]
    fragments.append(declaration)
    for params in test_params_for_test_case(test_case):
        fragments += get_test_lines(test_case, params)
    return fragments