def test_validate_args():
    """Test the existence check for the gene expression file.

    The validation function must raise ``ArgumentValidateException`` when
    the gene expression file is missing.
    """
    # Parse outside the ``raises`` block so that only ``validate_args`` can
    # satisfy the expectation; an error raised while parsing must not make
    # this test pass.
    args = parse_args(["mock_filename"])

    with pytest.raises(ArgumentValidateException):
        validate_args(args)
def test_validate_tpm_args(create_expression_files, create_output_files):
    """Test that TPM output without a gene lengths file fails validation.

    The command line requests a TPM output for an existing expression file
    but omits ``--gene-lengths``, so ``validate_args`` is expected to raise
    ``ArgumentValidateException``.
    """
    expression_path = create_expression_files[0]
    tpm_output_path = create_output_files[0]

    cmdline = [f"--tpm-output={tpm_output_path}", f"{expression_path}"]

    # Parse outside the ``raises`` block so that only ``validate_args`` can
    # satisfy the expectation; an error raised while parsing must not make
    # this test pass.
    args = parse_args(cmdline)

    with pytest.raises(ArgumentValidateException):
        validate_args(args)
def test_parse_args():
    """Test that the argument parser stores each option on its attribute."""
    cmdline = [
        "--gene-lengths=lengths.tsv",
        "--cpm-output=expr.cpm.tsv",
        "--tpm-output=expr.tpm.tsv",
        "expr.tsv",
    ]
    args = parse_args(cmdline)

    # One assert per attribute: a failure then names the offending field and
    # shows its actual value, instead of a bare ``assert all(...)`` failure.
    assert args.expression == "expr.tsv"
    assert args.gene_lengths == "lengths.tsv"
    assert args.tpm_output == "expr.tpm.tsv"
    assert args.cpm_output == "expr.cpm.tsv"
def test_normalization_output_success(create_expression_files,
                                      create_output_files):
    """Test TPM and CPM output.

    This test implements the contents of the main function and checks that
    the normalization outputs are correctly written to files: arguments are
    parsed and validated, CPM and TPM tables are computed and written as
    tab-separated files, and the files are read back and compared against
    manually computed values.
    """
    expression_path, gene_lengths_path = create_expression_files
    tpm_output_path, cpm_output_path = create_output_files

    cmdline = [
        f"--gene-lengths={gene_lengths_path}",
        f"--cpm-output={cpm_output_path}",
        f"--tpm-output={tpm_output_path}",
        f"{expression_path}",
    ]

    args = parse_args(cmdline)

    # Validation must pass when all referenced files exist.
    validate_args(args)

    # Compute the normalizations and write them to the output files.
    expressions = load_expressions(expression_path)

    # The conditionals are kept because, per the docstring, this test
    # replicates the main function -- presumably its optional-output branches.
    if cpm_output_path:
        cpm_table = cpm(expressions)
        cpm_table.to_csv(cpm_output_path, sep="\t")

    if tpm_output_path:
        gene_lengths = load_gene_lengths(gene_lengths_path)
        tpm_table = tpm(expressions, gene_lengths)
        tpm_table.to_csv(tpm_output_path, sep="\t")

    # Re-read the written files and compare the S1/S2 columns with the
    # manually computed values. Values are cast to int before comparing, so
    # the expected numbers are integers. ``assert_array_equal`` reports the
    # mismatching elements on failure, unlike ``assert np.all(a == b)``.
    expected_cpm = [[23809, 20000], [952380, 950000], [23809, 30000]]
    written_cpm = pd.DataFrame(pd.read_csv(cpm_output_path, sep="\t"),
                               columns=["S1", "S2"])
    np.testing.assert_array_equal(np.asarray(written_cpm, dtype=int),
                                  expected_cpm)

    expected_tpm = [[20704, 17400], [959266, 957349], [20029, 25250]]
    written_tpm = pd.DataFrame(pd.read_csv(tpm_output_path, sep="\t"),
                               columns=["S1", "S2"])
    np.testing.assert_array_equal(np.asarray(written_tpm, dtype=int),
                                  expected_tpm)