def test_get_symbols_omits_error_nodes():
    """Only the ``x`` symbol should come back for ``x_plus_error.xml``.

    NOTE(review): the fixture presumably holds an error node next to ``x``;
    its contents are not visible from this test -- confirm.
    """
    fixture = get_test_path(os.path.join("mathml", "x_plus_error.xml"))
    with open(fixture) as fixture_file:
        symbols = get_symbols(fixture_file.read())

    assert len(symbols) == 1

    # Select the symbol whose character list is exactly [0].
    matches = [sym for sym in symbols if sym.characters == [0]]
    x = matches[0]
    assert x.mathml.strip() == "<mi>x</mi>"
def test_parse_consecutive_mi():
    """``relu.xml`` should produce a single childless ``<mi>ReLU</mi>`` symbol.

    The symbol is expected to span four characters.
    """
    fixture = get_test_path(os.path.join("mathml", "relu.xml"))
    with open(fixture) as fixture_file:
        symbols = get_symbols(fixture_file.read())

    assert len(symbols) == 1

    (relu,) = symbols[:1]
    assert len(relu.children) == 0
    assert len(relu.characters) == 4
    assert relu.mathml == "<mi>ReLU</mi>"
def test_get_symbols_for_subscript():
    """``x_sub_i.xml`` should yield three symbols, one of them the ``msub``.

    The ``msub`` symbol must wrap ``x`` and ``i`` and cover characters
    0 and 1.
    """
    fixture = get_test_path(os.path.join("mathml", "x_sub_i.xml"))
    with open(fixture) as fixture_file:
        symbols = get_symbols(fixture_file.read())

    assert len(symbols) == 3

    # The first symbol whose MathML mentions "msub" is the subscripted one.
    subscripted = [sym for sym in symbols if "msub" in sym.mathml]
    x_sub_i = subscripted[0]

    pattern = r"<msub>\s*<mi>x</mi>\s*<mi>i</mi>\s*</msub>"
    assert re.search(pattern, x_sub_i.mathml)
    assert len(x_sub_i.characters) == 2
    assert 0 in x_sub_i.characters
    assert 1 in x_sub_i.characters
def test_get_symbols_for_characters():
    """``x_plus_y.xml`` should yield exactly the ``x`` and ``y`` symbols.

    ``x`` is expected at character index 0 and ``y`` at character index 2.
    """
    fixture = get_test_path(os.path.join("mathml", "x_plus_y.xml"))
    with open(fixture) as fixture_file:
        symbols = get_symbols(fixture_file.read())

    assert len(symbols) == 2

    at_zero = [sym for sym in symbols if sym.characters == [0]]
    x = at_zero[0]
    assert x.mathml.strip() == "<mi>x</mi>"

    at_two = [sym for sym in symbols if sym.characters == [2]]
    y = at_two[0]
    assert y.mathml.strip() == "<mi>y</mi>"
def test_get_symbol_sub_super_with_numeric_child():
    """Check the children of the outer ``msub`` in ``x_sub_four_sub_three.xml``.

    Three symbols are expected. The outer subscript symbol (the one whose
    MathML mentions both "msub" and "x") must have exactly two children:
    the ``x`` identifier and the numeric ``4`` sub ``3`` symbol.
    """
    fixture = get_test_path(
        os.path.join("mathml", "x_sub_four_sub_three.xml"))
    with open(fixture) as fixture_file:
        symbols = get_symbols(fixture_file.read())

    assert len(symbols) == 3

    outer_candidates = [
        sym for sym in symbols if "msub" in sym.mathml and "x" in sym.mathml
    ]
    x_sub_four_sub_three = outer_candidates[0]

    x = [sym for sym in symbols if sym.mathml == "<mi>x</mi>"][0]

    numeric_mathml = "<msub>\n<mn>4</mn>\n<mn>3</mn>\n</msub>"
    four_sub_three = [
        sym for sym in symbols if sym.mathml == numeric_mathml
    ][0]

    assert len(x_sub_four_sub_three.children) == 2
    assert x in x_sub_four_sub_three.children
    assert four_sub_three in x_sub_four_sub_three.children
def test_get_symbol_children():
    """Check the parent/child links for the ``x_sub_t_sub_i.xml`` fixture.

    Five symbols are expected. The outer subscript must have ``x`` and the
    inner subscript as its two children, and the inner subscript must have
    ``t`` and ``i`` as its two children.
    """
    fixture = get_test_path(os.path.join("mathml", "x_sub_t_sub_i.xml"))
    with open(fixture) as fixture_file:
        symbols = get_symbols(fixture_file.read())

    assert len(symbols) == 5

    # Outer subscript: first symbol mentioning both "msub" and "x".
    outer_candidates = [
        sym for sym in symbols if "msub" in sym.mathml and "x" in sym.mathml
    ]
    x_sub_t_sub_i = outer_candidates[0]

    # Inner subscript: first "msub" symbol that is not the outer object.
    inner_candidates = [
        sym for sym in symbols
        if "msub" in sym.mathml and sym is not x_sub_t_sub_i
    ]
    t_sub_i = inner_candidates[0]

    x = [sym for sym in symbols if sym.mathml == "<mi>x</mi>"][0]
    t = [sym for sym in symbols if sym.mathml == "<mi>t</mi>"][0]
    i = [sym for sym in symbols if sym.mathml == "<mi>i</mi>"][0]

    assert len(x_sub_t_sub_i.children) == 2
    assert x in x_sub_t_sub_i.children
    assert t_sub_i in x_sub_t_sub_i.children

    assert len(t_sub_i.children) == 2
    assert t in t_sub_i.children
    assert i in t_sub_i.children
def _get_symbol_data(arxiv_id: ArxivId, stdout: str) -> Iterator[SymbolData]:
    """Turn line-delimited JSON results into ``SymbolData`` records.

    ``stdout`` is stripped and split into lines; each line is decoded as a
    JSON object and one ``SymbolData`` is yielded per line.

    When a record's ``"success"`` field is exactly ``True``, its
    ``"mathMl"`` field is passed to ``get_characters`` and ``get_symbols``.
    Otherwise ``characters`` and ``symbols`` are both ``None``.

    Args:
        arxiv_id: Identifier copied unchanged into every yielded record.
        stdout: Text holding one JSON object per line.

    Yields:
        One ``SymbolData`` per line of ``stdout``.
    """
    for line in stdout.strip().splitlines():
        record = json.loads(line)
        parsed_characters, parsed_symbols = None, None

        # Only successful records carry MathML worth parsing.
        if record["success"] is True:
            mathml_text = record["mathMl"]
            parsed_characters = get_characters(mathml_text)
            parsed_symbols = get_symbols(mathml_text)

        yield SymbolData(
            arxiv_id=arxiv_id,
            success=record["success"],
            equation_index=int(record["i"]),
            tex_path=record["tex_path"],
            equation=record["equation"],
            characters=parsed_characters,
            symbols=parsed_symbols,
            equation_start=int(record["equation_start"]),
            equation_depth=int(record["equation_depth"]),
            context_tex=record["context_tex"],
            errorMessage=record["errorMessage"],
        )