示例#1
0
def test_case_for_transposes(batch, beta, trans_lhs, trans_rhs):
    """
    Build the generated test file for one batch/beta/transpose combination.

    Returns a list of text chunks in this order:

    * the license text from ``helpers.get_license()``;
    * the "don't modify" comment naming this generator script;
    * ``INCLUDES`` and ``DATA_TYPES``;
    * the typed test suite declaration. Its name comes from
      ``TEST_CASE_TPL``, and its transpose flags are lower-cased via
      ``helpers.to_lower_case_str``;
    * the test lines for every ``(m, k, n)`` triple taken from the Cartesian
      product of ``get_input_sizes()`` with itself three times.
    """
    generator_name = os.path.basename(__file__)
    suite_name = TEST_CASE_TPL.format(
        batch=batch, beta=beta, trans_lhs=trans_lhs, trans_rhs=trans_rhs)

    license_text = helpers.get_license()
    notice = helpers.get_dont_modify_comment(scriptname=generator_name)
    suite_decl = TYPED_TEST_SUITE_DECL_TPL.format(
        test_case=suite_name,
        trans_lhs=helpers.to_lower_case_str(trans_lhs),
        trans_rhs=helpers.to_lower_case_str(trans_rhs))
    lines = [license_text, notice, INCLUDES, DATA_TYPES, suite_decl]

    sizes = get_input_sizes()
    # NOTE(review): the same object is passed three times. If get_input_sizes()
    # ever returned a one-shot iterator, product() would yield nothing -- confirm
    # it returns a sequence.
    for m, k, n in itertools.product(sizes, sizes, sizes):
        lines += get_test_lines(batch, m, k, n, beta, trans_lhs, trans_rhs)
    return lines
示例#2
0
def output_for_test_case(test_case):
    """
    Build the generated test file for a single ``test_case`` description.

    Reads ``test_case.test_type``, ``.direction`` and ``.operation``. The
    suite name is formed from their camel-cased forms via ``TEST_CASE_TPL``.
    The typed test case declaration uses ``DIRECTION_MAP`` / ``OPERATION_MAP``
    lookups of the raw direction and operation values.

    Returns a list of text chunks: license, "don't modify" comment for this
    script, ``INCLUDES``, the typed test case declaration, then the lines for
    each parameter set from ``test_params_for_test_case``, and finally a
    trailing ``"\\n"``.
    """
    generator_name = os.path.basename(__file__)
    suite_name = TEST_CASE_TPL.format(
        test_type=helpers.to_camel_case(test_case.test_type),
        direction=helpers.to_camel_case(test_case.direction),
        operation=helpers.to_camel_case(test_case.operation))

    header = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=generator_name),
        INCLUDES,
        TYPED_TEST_CASE_DECL_TPL.format(
            test_case=suite_name,
            direction=DIRECTION_MAP[test_case.direction],
            operation=OPERATION_MAP[test_case.operation]),
    ]
    # Flatten the per-parameter-set line lists into one sequence, in order.
    body = [
        line
        for params in test_params_for_test_case(test_case)
        for line in get_test_lines(test_case, params)
    ]
    return header + body + ["\n"]
示例#3
0
def get_initial_boilerplate():
    """
    Return the fixed chunks placed at the top of the generated test file.

    The list holds, in order: the license text, the "don't modify" comment
    naming this generator script, ``INCLUDES`` and ``DATA_TYPES``.
    """
    generator_name = os.path.basename(__file__)
    license_text = helpers.get_license()
    notice = helpers.get_dont_modify_comment(scriptname=generator_name)
    return [license_text, notice, INCLUDES, DATA_TYPES]
示例#4
0
def transpose_test_case(n_dimensions):
    """
    Build the generated transpose test file for ``n_dimensions`` dimensions.

    Returns a list of text chunks: license, "don't modify" comment for this
    script, ``INCLUDES``, ``DATA_TYPES``, the typed test suite declaration
    named via ``TEST_CASE_TPL``, and then the test lines for every
    ``(in_shape, permutation)`` pair yielded by ``test_cases(n_dimensions)``.
    """
    generator_name = os.path.basename(__file__)
    suite_name = TEST_CASE_TPL.format(n_dimensions=n_dimensions)
    generated = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=generator_name),
        INCLUDES,
        DATA_TYPES,
        TYPED_TEST_SUITE_DECL_TPL.format(test_case=suite_name),
    ]
    generated += [
        line
        for shape, perm in test_cases(n_dimensions)
        for line in get_test_lines(shape, perm)
    ]
    return generated
示例#5
0
def output_for_test_case(test_case):
    """
    Build the generated test file for a single ``test_case`` description.

    Reads ``test_case.test_type``, ``.window`` and ``.stride``. The suite name
    comes from ``TEST_CASE_TPL`` with the camel-cased test type plus the
    window and stride. The same window and stride are also passed to
    ``TYPED_TEST_SUITE_DECL_TPL``.

    Returns a list of text chunks: license, "don't modify" comment for this
    script, ``INCLUDES``, ``DATA_TYPES``, the typed test suite declaration,
    then the lines for each parameter set from ``test_params_for_test_case``.
    """
    generator_name = os.path.basename(__file__)
    type_name = helpers.to_camel_case(test_case.test_type)
    window = test_case.window
    stride = test_case.stride
    suite_name = TEST_CASE_TPL.format(
        test_type=type_name, window=window, stride=stride)

    parts = [
        helpers.get_license(),
        helpers.get_dont_modify_comment(scriptname=generator_name),
        INCLUDES,
        DATA_TYPES,
        TYPED_TEST_SUITE_DECL_TPL.format(
            test_case=suite_name, window=window, stride=stride),
    ]
    for params in test_params_for_test_case(test_case):
        parts += get_test_lines(test_case, params)
    return parts