Esempio n. 1
0
def test_preprocess_filter(write_yaml, shared_datadir, add_cli_arguments):
    """Check the row counts produced by ``filter_main`` after doubling the libraries.

    Steps:
      1. Write a config with a 0.6 / 0.2 / 0.2 training/testing/validation
         split, then run ``make_false_main`` with the ``strict`` CLI argument.
      2. Double every line of both template-library CSVs in place
         (presumably to give the filter step duplicate rows to handle).
      3. Run ``filter_main`` with the same config.
      4. Assert how many lines each split output file contains.
    """

    def read_rows(name):
        # Lines of the data file, with line terminators stripped.
        with open(shared_datadir / name, "r") as handle:
            return handle.read().splitlines()

    def duplicate_file(filename):
        # Rewrite the file as two back-to-back copies of its own lines.
        doubled = read_rows(filename) * 2
        with open(shared_datadir / filename, "w") as handle:
            handle.write("\n".join(doubled))

    library_headers = [
        "index",
        "reaction_hash",
        "reactants",
        "products",
        "retro_template",
        "template_hash",
        "template_code",
    ]
    split_size = {"training": 0.6, "testing": 0.2, "validation": 0.2}
    config_path = write_yaml(
        {
            "file_prefix": str(shared_datadir / "make_false"),
            "split_size": split_size,
            "library_headers": library_headers,
        }
    )

    add_cli_arguments(f"{config_path} strict")
    make_false_main()

    for library in (
        "make_false_template_library_false.csv",
        "make_false_template_library.csv",
    ):
        duplicate_file(library)

    add_cli_arguments(config_path)
    filter_main()

    # Expected line count per split output, checked in this order.
    expected_row_counts = {
        "make_false_training.csv": 6,
        "make_false_testing.csv": 2,
        "make_false_validation.csv": 2,
    }
    for output_name, expected in expected_row_counts.items():
        assert len(read_rows(output_name)) == expected
Esempio n. 2
0
def test_make_false_products(write_yaml, shared_datadir, add_cli_arguments):
    """Run ``make_false_main`` in ``strict`` mode and check its false-library output.

    Writes a config naming the library column headers and the ``make_false``
    file prefix. It then asserts that ``make_false_template_library_false.csv``
    contains exactly two lines.
    """
    library_headers = [
        "index",
        "reaction_hash",
        "reactants",
        "products",
        "retro_template",
        "template_hash",
        "template_code",
    ]
    settings = {
        "file_prefix": str(shared_datadir / "make_false"),
        "library_headers": library_headers,
    }
    config_path = write_yaml(settings)
    add_cli_arguments(f"{config_path} strict")

    make_false_main()

    false_library = shared_datadir / "make_false_template_library_false.csv"
    with open(false_library, "r") as handle:
        produced = handle.read().splitlines()
    assert len(produced) == 2