Example #1
0
File: cpp.py  Project: qser/CLBlast
def performance_test(routine, level_string):
    """Generates the body of a performance test for a specific routine.

    Produces the C++ source of a ``main`` function that reads the requested
    precision with ``clblast::GetPrecision`` (passing the precision of the
    routine's first flavour as the default argument) and switches over the
    precisions H, S, D, C and Z. A ``case`` whose precision matches one of the
    routine's flavours calls ``clblast::RunClient`` with the corresponding
    ``TestX<name>`` instantiation; every other ``case`` throws
    ``std::runtime_error``.

    :param routine: routine description; this function uses its
        ``lowercase_name()``, ``plain_name()`` and ``flavours`` (each flavour
        providing ``precision_name`` and ``test_template()``).
    :param level_string: spliced verbatim into the include path
        ``test/routines/level<level_string>/x<name>.hpp``.
    :return: the generated C++ source as a single string.
    """
    lines = [
        "#include \"test/performance/client.hpp\"" + NL,
        "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL,
        "// Main function (not within the clblast namespace)" + NL,
        "int main(int argc, char *argv[]) {" + NL,
        "  const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);" + NL,
    ]
    fallback_precision = convert.precision_to_full_name(routine.flavours[0].precision_name)
    lines.append("  switch(clblast::GetPrecision(command_line_args, clblast::Precision::k"
                 + fallback_precision + ")) {" + NL)
    for letter in ("H", "S", "D", "C", "Z"):
        lines.append("    case clblast::Precision::k" + convert.precision_to_full_name(letter) + ":")
        matching = [flavour for flavour in routine.flavours if flavour.precision_name == letter]
        for flavour in matching:
            # Every matching flavour gets its own RunClient call; in the
            # generated C++ only the first is reachable because of the ``break``.
            lines.append(NL + "      clblast::RunClient<clblast::TestX" + routine.plain_name()
                         + flavour.test_template() + ">(argc, argv); break;" + NL)
        if not matching:
            lines.append(" throw std::runtime_error(\"Unsupported precision mode\");" + NL)
    lines.extend(["  }" + NL, "  return 0;" + NL, "}" + NL])
    return "".join(lines)
Example #2
0
File: cpp.py  Project: gpu/CLBlast
def performance_test(routine, level_string):
    """Generates the body of a performance test for a specific routine.

    Produces the C++ source of a ``main`` function that reads the requested
    precision with ``clblast::GetPrecision`` (passing the precision of the
    routine's first flavour as the default argument) and switches over the
    precisions H, S, D, C and Z. A ``case`` whose precision matches one of the
    routine's flavours calls ``clblast::RunClient`` with the corresponding
    ``TestX<name>`` instantiation; every other ``case`` throws
    ``std::runtime_error``.

    :param routine: routine description; this function uses its ``name``,
        ``batched``, ``lowercase_name()``, ``plain_name()`` and ``flavours``
        (each flavour providing ``precision_name`` and
        ``test_template(extra_template_argument)``).
    :param level_string: spliced verbatim into the include path
        ``test/routines/level<level_string>/x<name>.hpp``.
    :return: the generated C++ source as a single string.
    """
    lines = [
        "#include \"test/performance/client.hpp\"" + NL,
        "#include \"test/routines/level" + level_string + "/x" + routine.lowercase_name() + ".hpp\"" + NL + NL,
        "// Main function (not within the clblast namespace)" + NL,
        "int main(int argc, char *argv[]) {" + NL,
        "  const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);" + NL,
    ]
    default = convert.precision_to_full_name(routine.flavours[0].precision_name)
    lines.append("  switch(clblast::GetPrecision(command_line_args, clblast::Precision::k" + default + ")) {" + NL)

    # Non-batched "gemm" passes an extra "0, " to every flavour's test_template().
    # The value depends only on the routine, so compute it once instead of per
    # matching flavour inside the nested loop. The name check comes first, so
    # non-gemm routines never read ``batched``.
    extra_template_argument = "0, " if routine.name == "gemm" and not routine.batched else ""

    for precision in ["H", "S", "D", "C", "Z"]:
        lines.append("    case clblast::Precision::k" + convert.precision_to_full_name(precision) + ":")
        found = False
        for flavour in routine.flavours:
            if flavour.precision_name == precision:
                # One RunClient call per matching flavour; in the generated C++
                # only the first is reachable because of the ``break``.
                lines.append(NL + "      clblast::RunClient<clblast::TestX" + routine.plain_name())
                lines.append(flavour.test_template(extra_template_argument))
                lines.append(">(argc, argv); break;" + NL)
                found = True
        if not found:
            lines.append(" throw std::runtime_error(\"Unsupported precision mode\");" + NL)
    lines.append("  }" + NL)
    lines.append("  return 0;" + NL)
    lines.append("}" + NL)
    # Join once rather than growing a string with repeated ``+=``.
    return "".join(lines)