# Example #1
# 0
class TestAnnotator(unittest.TestCase):
    """Exercise ``Annotator`` construction, configuration and argument checks."""

    def setUp(self):
        # Wide-form fixtures: two plain samples (already drawn as a boxplot)
        # and a two-column frame used by the "data only" tests.
        self.data = [[1, 2, 3], [2, 5, 7]]
        self.data2 = pd.DataFrame([[1, 2], [2, 5], [3, 7]], columns=["X", "Y"])
        self.ax = sns.boxplot(data=self.data)

        # Long-form fixture: the same four (x, y) points once per colour,
        # keyed 1..8 and transposed so that each key becomes a row.
        points = (("a", 15), ("a", 16), ("b", 17), ("b", 18))
        records = {}
        for colour in ("blue", "red"):
            for x_value, y_value in points:
                records[len(records) + 1] = {
                    'x': x_value, 'y': y_value, 'color': colour}
        self.df = pd.DataFrame.from_dict(records).T
        # The transpose leaves every column as object dtype; y must be numeric.
        self.df["y"] = self.df["y"].astype(float)

        blue_a, blue_b, red_a = ("a", "blue"), ("b", "blue"), ("a", "red")
        self.pairs_for_df = [(blue_a, blue_b), (blue_a, red_a)]
        self.params_df = dict(
            data=self.df,
            x="x",
            y="y",
            hue="color",
            order=["a", "b"],
            hue_order=['red', 'blue'],
        )

    # ------------------------------------------------------------------
    # Shared builders. Their names do not start with ``test`` so the
    # unittest loader never collects them.
    # ------------------------------------------------------------------

    def _init_df_with(self, pairs):
        # Fresh hue boxplot of ``self.df`` plus an Annotator for ``pairs``.
        self.ax = sns.boxplot(**self.params_df)
        self.annot = Annotator(self.ax, pairs=pairs, **self.params_df)

    def _annotate_df_by_hue(self, pairs, hue_order):
        # Annotator on the existing axes with explicit hue arguments.
        self.annot = Annotator(self.ax, pairs,
                               data=self.df, x="x", y="y",
                               order=["a", "b"], hue='color',
                               hue_order=hue_order)

    def _annotator_on_data2(self):
        # Draw ``self.data2`` onto the current axes and annotate X vs Y.
        self.ax = sns.boxplot(ax=self.ax, data=self.data2)
        return Annotator(self.ax, pairs=[("X", "Y")], data=self.data2)

    def _configure_simple(self, test_name):
        self.annot.configure(test=test_name, text_format="simple")

    def _replot_and_configure(self, test_name):
        # Re-target the existing annotator at a brand new plot.
        self.ax = sns.boxplot(**self.params_df)
        self.annot.new_plot(self.ax, self.pairs_for_df, **self.params_df)
        self._configure_simple(test_name)

    @staticmethod
    def _expected_for_scipy(before_1_7, from_1_7):
        # The expected p-values depend on whether scipy is older than 1.7.
        if version.parse(scipy.__version__) < version.parse("1.7"):
            return before_1_7
        return from_1_7

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def test_init_simple(self):
        self.annot = Annotator(self.ax, [(0, 1)], data=self.data)

    def test_init_df(self):
        self._init_df_with(self.pairs_for_df)

    def test_init_barplot(self):
        bar_ax = sns.barplot(data=self.data)
        self.annot = Annotator(bar_ax, [(0, 1)], plot="barplot",
                               data=self.data)

    def test_init_df_inverted(self):
        # Same as test_init_df, but with the pairs in reverse order.
        self._init_df_with(list(reversed(self.pairs_for_df)))

    def test_not_implemented_plot(self):
        with self.assertRaises(NotImplementedError):
            Annotator(self.ax, [(0, 1)], data=self.data, plot="thatplot")

    # ------------------------------------------------------------------
    # Validation of pairs / order / hue_order
    # ------------------------------------------------------------------

    def test_test_name_provided(self):
        self.test_init_simple()
        with self.assertRaisesRegex(ValueError, "test"):
            self.annot.apply_test()

    def test_unmatched_x_in_box_pairs_without_hue(self):
        with self.assertRaisesRegex(ValueError, "(specified in `pairs`)"):
            self.annot = Annotator(self.ax, [(0, 2)], data=self.data)

    def test_order_in_x(self):
        with self.assertRaisesRegex(ValueError, "(specified in `order`)"):
            self.annot = Annotator(self.ax, [(0, 2)], data=self.data,
                                   order=[0, 1, 2])

    def test_working_hue_orders(self):
        self._annotate_df_by_hue([(("a", "blue"), ("b", "blue"))],
                                 ['red', 'blue'])

    def test_unmatched_hue_in_hue_order(self):
        with self.assertRaisesRegex(ValueError, "(specified in `hue_order`)"):
            self._annotate_df_by_hue([(("a", "blue"), ("b", "blue"))],
                                     ['red', 'yellow'])

    def test_unmatched_hue_in_box_pairs(self):
        with self.assertRaisesRegex(ValueError, "(specified in `pairs`)"):
            self._annotate_df_by_hue([(("a", "yellow"), ("b", "blue"))],
                                     ['red', 'blue'])

    def test_unmatched_x_in_box_pairs_with_hue(self):
        with self.assertRaisesRegex(ValueError, "(specified in `pairs`)"):
            self._annotate_df_by_hue([(("c", "blue"), ("b", "blue"))],
                                     ['red', 'blue'])

    # ------------------------------------------------------------------
    # configure() / reset_configuration()
    # ------------------------------------------------------------------

    def test_location(self):
        self.test_init_simple()
        with self.assertRaisesRegex(ValueError, "argument `loc`"):
            self.annot.configure(loc="somewhere")

    def test_unknown_parameter(self):
        self.test_init_simple()
        message = re.escape("parameter(s) \"that\"")
        with self.assertRaisesRegex(InvalidParametersError, message):
            self.annot.configure(that="this")

    def test_format(self):
        self.test_init_simple()
        with self.assertRaisesRegex(ValueError, "argument `text_format`"):
            self.annot.configure(pvalue_format={'text_format': 'that'})

    def test_reconfigure_alpha(self):
        self.test_init_simple()
        with self.assertWarnsRegex(UserWarning, "pvalue_thresholds"):
            self.annot.configure(alpha=0.1)
        self.annot.reset_configuration()
        self.assertEqual(0.05, self.annot.alpha)

    def test_reconfigure_alpha_with_thresholds(self):
        self.test_init_simple()
        self.annot.configure(alpha=0.1,
                             pvalue_format={"pvalue_thresholds": DEFAULT})
        self.annot.reset_configuration()
        self.assertEqual(0.05, self.annot.alpha)

    # ------------------------------------------------------------------
    # Corrections and custom annotations
    # ------------------------------------------------------------------

    def test_apply_comparisons_correction(self):
        self.test_init_simple()
        corrected = self.annot._apply_comparisons_correction([])
        self.assertIsNone(corrected)

    def test_correct_num_custom_annotations(self):
        self.test_init_simple()
        with self.assertRaisesRegex(ValueError, "same length"):
            self.annot.set_custom_annotations(["One", "Two"])

    # ------------------------------------------------------------------
    # Annotation text
    # ------------------------------------------------------------------

    def test_get_annotation_text_undefined(self):
        self.test_init_simple()
        self.assertIsNone(self.annot.get_annotations_text())

    def test_get_annotation_text_calculated(self):
        self.test_init_simple()
        self.annot.configure(test="Mann-Whitney", verbose=2)
        self.annot.apply_test()
        self.assertEqual(["ns"], self.annot.get_annotations_text())

    def test_get_annotation_text_in_input_order(self):
        self.test_init_df()
        self._configure_simple("Mann-Whitney")
        self.annot.apply_test()

        expected = self._expected_for_scipy(
            ['M.W.W. p = 0.25', 'M.W.W. p = 0.67'],
            ['M.W.W. p = 0.33', 'M.W.W. p = 1.00'])
        self.assertEqual(expected, self.annot.get_annotations_text())

    def test_get_annotation_text_in_input_order_inverted(self):
        self.test_init_df_inverted()
        self._configure_simple("Mann-Whitney")
        self.annot.apply_test()

        expected = self._expected_for_scipy(
            ['M.W.W. p = 0.67', 'M.W.W. p = 0.25'],
            ['M.W.W. p = 1.00', 'M.W.W. p = 0.33'])
        self.assertEqual(expected, self.annot.get_annotations_text())

    # ------------------------------------------------------------------
    # Re-using an annotator on a new plot
    # ------------------------------------------------------------------

    def test_apply_no_apply_warns(self):
        self.test_init_df_inverted()
        self._configure_simple("Mann-Whitney")
        self.annot.apply_and_annotate()

        # Reconfigured for another test but annotated without applying it.
        self._replot_and_configure("Levene")
        with self.assertWarns(UserWarning):
            self.annot.annotate()

    def test_apply_apply_no_warns(self):
        self.test_init_df_inverted()
        self._configure_simple("Mann-Whitney")
        self.annot.apply_and_annotate()

        self._replot_and_configure("Mann-Whitney-gt")
        self.annot.apply_and_annotate()

    def test_valid_parameters_df_data_only(self):
        annot = self._annotator_on_data2()
        annot.configure(test="Mann-Whitney").apply_and_annotate()

    def test_comparisons_correction_by_name(self):
        annot = self._annotator_on_data2()
        annot.configure(test="Mann-Whitney", comparisons_correction="BH")
        annot.apply_and_annotate()

    def test_empty_annotator_wo_new_plot_raises(self):
        empty = Annotator.get_empty_annotator()
        with self.assertRaises(RuntimeError):
            empty.configure(test="Mann-Whitney")

    def test_empty_annotator_then_new_plot_ok(self):
        empty = Annotator.get_empty_annotator()
        self.ax = sns.boxplot(ax=self.ax, data=self.data2)
        empty.new_plot(self.ax, pairs=[("X", "Y")],
                       data=self.data2)
        empty.configure(test="Mann-Whitney")

    # ------------------------------------------------------------------
    # _ensure_ax_operation_format: each malformed operation is rejected
    # ------------------------------------------------------------------

    def test_ensure_ax_operation_format_args_not_ok(self):
        operation = ["func", "param", None]
        with self.assertRaises(ValueError):
            _ensure_ax_operation_format(operation)

    def test_ensure_ax_operation_format_op_not_ok(self):
        operation = ["func", ["param"]]
        with self.assertRaises(ValueError):
            _ensure_ax_operation_format(operation)

    def test_ensure_ax_operation_format_kwargs_not_ok(self):
        operation = ["func", ["param"], {"that"}]
        with self.assertRaises(ValueError):
            _ensure_ax_operation_format(operation)

    def test_ensure_ax_operation_format_func_not_ok(self):
        operation = [sum, ["param"], {"that": "this"}]
        with self.assertRaises(ValueError):
            _ensure_ax_operation_format(operation)
# Example #2
# 0
def main():
    """Plot variant frequencies for the RV patient samples and save the figures.

    Loads ``data_filter.pkl``, ``data_filter_ag.pkl`` and ``data_filter_uc.pkl``
    from ``input_dir + prefix`` and writes every figure, plus the ``sts.csv``
    text captured while annotating the box plot, into a directory named
    ``<YYYYMMDD>_<prefix>`` under ``input_dir``. All paths are hard-coded.

    Takes no arguments and returns ``None``.
    """
    input_dir = "/Users/odedkushnir/PhD_Projects/After_review/AccuNGS/RV/patients/"
    prefix = "inosine_predict_context_freq0.01"
    date = datetime.today().strftime("%Y%m%d")
    output_dir = input_dir + "{0}_{1}".format(date, prefix)
    try:
        os.mkdir(output_dir)
    except OSError:
        # Also reached when the directory already exists; the run continues.
        print("Creation of the directory %s failed" % output_dir)
    else:
        print("Successfully created the directory %s " % output_dir)

    # Pickles execute code on load: only read files from a trusted location.
    data_filter = pd.read_pickle(input_dir + prefix + "/data_filter.pkl")
    data_filter_ag = pd.read_pickle(input_dir + prefix + "/data_filter_ag.pkl")
    data_filter_uc = pd.read_pickle(input_dir + prefix + "/data_filter_uc.pkl")
    # Shorten the control label so that it matches ``label_order`` below.
    data_filter["label"] = np.where(
        data_filter["label"] == "RNA Control\nPrimer ID", "RNA\nControl",
        data_filter["label"])

    # Plot orderings
    label_order = [
        "RNA\nControl", "p3 Cell Culture\nControl", "Patient-1", "Patient-4",
        "Patient-5", "Patient-9", "Patient-16", "Patient-17", "Patient-20"
    ]
    transition_order = ["A>G", "U>C", "G>A", "C>U"]
    type_order2 = ["Synonymous", "Non-Synonymous"]
    context_order_uc = ["UpA", "UpU", "UpG", "UpC"]
    type_order_ag = ["Synonymous", "Non-Synonymous", "NonCodingRegion"]
    adar_preference = ["High", "Intermediate", "Low"]
    plus_minus = u"\u00B1"
    ci_label = "Variant Frequency {} CI=95%".format(plus_minus)

    # Within every sample compare A>G against each other transition. The
    # order (all G>A pairs, then U>C, then C>U) is the order of the
    # annotations on the plot.
    pairs = [((label, "A>G"), (label, other))
             for other in ("G>A", "U>C", "C>U")
             for label in label_order]

    # Transition mutations, point plot
    g2 = sns.catplot(x="label",
                     y="frac_and_weight",
                     data=data_filter,
                     hue="Mutation",
                     order=label_order,
                     palette=mutation_palette(4),
                     kind="point",
                     dodge=0.5,
                     hue_order=transition_order,
                     join=False,
                     estimator=weighted_varaint,
                     orient="v",
                     legend=True)
    g2.set_axis_labels("", ci_label)
    g2.set(yscale='log')
    g2.set(ylim=(10**-5, 10**-3))
    g2.set_xticklabels(fontsize=10, rotation=90)
    g2.savefig(output_dir + "/Transition_Mutations_point_plot", dpi=300)
    plt.close()

    # Transition mutations, box plot with statistical annotations
    data_filter["label"] = data_filter["label"].astype(str)
    data_filter["Frequency"] = data_filter["Frequency"].astype(float)
    passage_g = sns.boxplot(x="label",
                            y="Frequency",
                            data=data_filter,
                            hue="Mutation",
                            order=label_order,
                            palette=mutation_palette(4),
                            dodge=True,
                            hue_order=transition_order)
    passage_g.set_yscale('log')
    passage_g.set_ylim(10**-6, 10**-1)
    passage_g.set(xlabel="", ylabel="Variant Frequency")
    passage_g.set_xticklabels(labels=label_order, fontsize=10, rotation=90)

    annot = Annotator(passage_g,
                      pairs,
                      x="label",
                      y="Frequency",
                      hue="Mutation",
                      data=data_filter,
                      order=label_order,
                      hue_order=transition_order)
    annot.configure(test='t-test_welch',
                    text_format='star',
                    loc='outside',
                    verbose=2,
                    comparisons_correction="Bonferroni")
    annot.apply_test()
    file_path = output_dir + "/sts.csv"
    # Whatever annotate() prints goes to sts.csv instead of the console.
    with open(file_path, "w") as o:
        with contextlib.redirect_stdout(o):
            passage_g, test_results = annot.annotate()
    plt.legend(bbox_to_anchor=(1.05, 0.5), loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(output_dir + "/Transition_Mutations_box_stat_plot_patients",
                dpi=300)
    plt.close()

    # A>G, ADAR-like vs. not, one column per mutation type.
    # x/y are passed by keyword: seaborn >= 0.12 rejects them positionally,
    # and the g2 call above already uses keywords.
    g5 = sns.catplot(x="label",
                     y="frac_and_weight",
                     data=data_filter_ag,
                     hue="ADAR_like",
                     order=label_order,
                     palette=mutation_palette(2),
                     kind="point",
                     dodge=True,
                     hue_order=[True, False],
                     estimator=weighted_varaint,
                     orient="v",
                     col="Type",
                     join=False,
                     col_order=type_order2)
    g5.set_axis_labels("", ci_label)
    g5.set(yscale='log')
    g5.set(ylim=(7 * 10**-7, 4 * 10**-3))
    g5.set_xticklabels(rotation=90)
    g5.savefig(output_dir + "/Context_point_plot", dpi=300)
    plt.close()

    # A>G split by 5' ADAR preference, one column per mutation type
    mutation_ag = sns.catplot(x="label",
                              y="frac_and_weight",
                              data=data_filter_ag,
                              hue="5`_ADAR_Preference",
                              palette=mutation_palette(3, adar=True, ag=True),
                              kind="point",
                              dodge=True,
                              estimator=weighted_varaint,
                              order=label_order,
                              orient="v",
                              col="Type",
                              join=False,
                              col_order=type_order_ag,
                              hue_order=adar_preference)
    mutation_ag.set(yscale="log")
    mutation_ag.set(ylim=(1 * 10**-5, 1 * 10**-2))
    mutation_ag.set_xticklabels(rotation=90)
    mutation_ag.fig.suptitle("A>G ADAR_like Mutation in RV patients", y=0.99)
    plt.subplots_adjust(top=0.85)
    mutation_ag.set_axis_labels("", ci_label)
    mutation_ag.savefig(output_dir + "/ag_ADAR_like_Mutation_col_patients.png",
                        dpi=300)
    plt.close()

    # A>G, ADAR-like vs. not, all mutation types pooled
    g6 = sns.catplot(x="label",
                     y="frac_and_weight",
                     data=data_filter_ag,
                     hue="ADAR_like",
                     order=label_order,
                     palette=mutation_palette(2),
                     kind="point",
                     dodge=True,
                     hue_order=[True, False],
                     estimator=weighted_varaint,
                     orient="v",
                     join=False)
    g6.set_axis_labels("", ci_label)
    g6.set(yscale='log')
    g6.set(ylim=(7 * 10**-7, 4 * 10**-3))
    g6.set_xticklabels(rotation=90)
    g6.savefig(output_dir + "/Context_point_all_mutations_type_plot", dpi=300)
    plt.close()

    # U>C by following-nucleotide context, one column per mutation type
    g9 = sns.catplot(x="label",
                     y="frac_and_weight",
                     data=data_filter_uc,
                     hue="Next",
                     order=label_order,
                     palette="tab20",
                     hue_order=context_order_uc,
                     estimator=weighted_varaint,
                     orient="v",
                     dodge=True,
                     kind="point",
                     col="Type",
                     join=False,
                     col_order=type_order2)
    g9.set_axis_labels("", ci_label)
    g9.set(yscale='log')
    g9.set(ylim=(10**-5, 10**-2))
    g9.set_xticklabels(rotation=90)
    g9.savefig(output_dir + "/UC_Context_point_plot", dpi=300)
    plt.close()

    # One aggregated frequency per (ADAR_like, label, Type) group
    data_filter_ag_grouped = (
        data_filter_ag
        .groupby(["ADAR_like", "label", "Type"])["frac_and_weight"]
        .agg(weighted_varaint)
        .reset_index()
        .rename(columns={"frac_and_weight": "Frequency"}))
    data_filter_ag_grouped["Frequency"] = data_filter_ag_grouped[
        "Frequency"].astype(float)
    print(data_filter_ag_grouped.to_string())

    # NOTE(review): this label holds an invisible non-ASCII mark right after
    # "Culture" and has no "p3 " prefix, so it equals no entry of
    # ``label_order``; the selection is probably empty. Confirm the label
    # actually present in the data before relying on this frame. The result
    # is not used any further in this function.
    data_filter_ag_grouped_silent = data_filter_ag_grouped[
        (data_filter_ag_grouped["Type"] == "Synonymous")
        & (data_filter_ag_grouped["label"] == "Cell Cultureֿ\nControl")]
# Example #3
# 0
def _label_adar_preference(frame, mutation, preference_column, fallback):
    """Return per-row labels that split ``mutation`` by ADAR preference level.

    Rows of ``frame`` whose "Mutation" equals ``mutation`` and whose
    ``preference_column`` is "High", "Intermediate" or "Low" get the label
    ``"<level>\\nADAR-like\\n<mutation>"``; every other row keeps its value
    from ``fallback``. The result is a NumPy array aligned with ``frame``.
    """
    # Evaluated once, up front, so the caller may overwrite "Mutation" with
    # the returned labels.
    is_mutation = frame["Mutation"] == mutation
    labels = fallback
    # A row has a single preference value, so the three passes never compete.
    for level in ("High", "Intermediate", "Low"):
        matches = is_mutation & (frame[preference_column] == level)
        labels = np.where(matches,
                          "{0}\nADAR-like\n{1}".format(level, mutation),
                          labels)
    return labels


def plots(input_dir, date, data_filter, virus, passage_order, transition_order, pairs, label_order, pairs_adar, filter_reads=None):
    """Draw and save the transition and ADAR-preference plots for one virus.

    Four figures are written into ``input_dir + date + "_plots"``: a point
    plot and an annotated box plot of the transition mutations per passage,
    and the same two plots for synonymous mutations split by ADAR preference.
    The text printed while annotating the box plots is captured in
    ``sts.csv`` and ``sts_adar.csv``.

    Parameters:
        input_dir: directory prefix; the output directory name is appended
            to it directly, without a separator.
        date: string used in the output directory name.
        data_filter: DataFrame with the columns "passage", "Mutation",
            "Type", "Frequency", "frac_and_weight", "5`_ADAR_Preference" and
            "3`_ADAR_Preference" (plus "Prob", "no_variants" and "Read_count"
            when ``filter_reads`` is True). It is modified in place: "passage"
            is converted to strings with a "p" prefix.
        virus: name appended to the saved file names.
        passage_order: x-axis order of the passages.
        transition_order: hue order for the transition plots.
        pairs: group pairs compared on the transition box plot.
        label_order: not used by this function; kept for existing callers.
        pairs_adar: group pairs compared on the ADAR box plot.
        filter_reads: when exactly ``True``, "no_variants" is zeroed where
            "Prob" < 0.95 and only rows with "Read_count" > 10000 are plotted.

    Returns:
        None.
    """
    output_dir = input_dir + date + "_plots"
    plus_minus = u"\u00B1"
    try:
        os.mkdir(output_dir)
    except OSError:
        # Also reached when the directory already exists; the run continues.
        print("Creation of the directory %s failed" % output_dir)
    else:
        print("Successfully created the directory %s " % output_dir)
    if filter_reads is True:
        data_filter["no_variants"] = np.where(data_filter["Prob"] < 0.95, 0, data_filter["no_variants"])
        # Keep only well-covered rows. The previous code assigned the whole
        # filtered frame to the single "Read_count" column, which raises for
        # a multi-column frame. From here on a filtered copy is used, so the
        # caller's frame is not modified any further.
        data_filter = data_filter[data_filter["Read_count"] > 10000].copy()

    # Passages become strings prefixed with "p"; the RNA control keeps its label.
    data_filter["passage"] = data_filter["passage"].astype(str)
    data_filter["passage"] = np.where(data_filter["passage"] != "RNA\nControl", "p" + data_filter["passage"], data_filter["passage"])

    # Transition mutations, point plot. x/y are passed by keyword because
    # seaborn >= 0.12 no longer accepts them positionally.
    g2 = sns.catplot(x="passage", y="frac_and_weight", data=data_filter, hue="Mutation", order=passage_order,
                     palette=mutation_palette(4), kind="point", dodge=0.5, hue_order=transition_order,
                     join=False, estimator=weighted_varaint, orient="v")
    g2.set_axis_labels("Passage", "Variant Frequency {} CI=95%".format(plus_minus))
    g2.set(yscale='log')
    g2.set(ylim=(10 ** -6, 10 ** -2))
    g2.savefig(output_dir + "/Transition_Mutations_point_plot_{0}".format(virus), dpi=300)
    plt.close()

    # Transition mutations, box plot with statistical annotations
    passage_g = sns.boxplot(x="passage", y="Frequency", data=data_filter, hue="Mutation", order=passage_order,
                            palette=mutation_palette(4), dodge=True, hue_order=transition_order)
    passage_g.set_yscale('log')
    passage_g.set_ylim(10 ** -6, 10 ** -1)
    passage_g.set(xlabel="Passage", ylabel="Variant Frequency")

    annot = Annotator(passage_g, pairs, x="passage", y="Frequency", hue="Mutation", data=data_filter,
                      order=passage_order, hue_order=transition_order)
    annot.configure(test='t-test_welch', text_format='star', loc='outside', verbose=2,
                    comparisons_correction="Bonferroni")
    annot.apply_test()
    file_path = output_dir + "/sts.csv"
    # Whatever annotate() prints goes to sts.csv instead of the console.
    with open(file_path, "w") as o:
        with contextlib.redirect_stdout(o):
            passage_g, test_results = annot.annotate()
    plt.legend(bbox_to_anchor=(1.05, 0.5), loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(output_dir + "/Transition_Mutations_box_stat_plot_{0}".format(virus), dpi=300)
    plt.close()

    # Work on an explicit copy: columns are assigned below, and assigning to
    # a ``.loc`` slice triggers pandas' SettingWithCopy warning.
    data_filter_synonymous = data_filter.loc[data_filter.Type == "Synonymous"].copy()
    # First relabel A>G by 5' preference, then build "Mutation_adar" on top
    # of that by relabelling U>C by 3' preference.
    data_filter_synonymous["Mutation"] = _label_adar_preference(
        data_filter_synonymous, "A>G", "5`_ADAR_Preference", data_filter_synonymous["Mutation"])
    data_filter_synonymous["Mutation_adar"] = _label_adar_preference(
        data_filter_synonymous, "U>C", "3`_ADAR_Preference", data_filter_synonymous["Mutation"])
    # Only the High and Low groups are plotted; Intermediate rows are labelled
    # above but left out of the hue order.
    mutation_adar_order = ["High\nADAR-like\nA>G", "Low\nADAR-like\nA>G",
                           "High\nADAR-like\nU>C", "Low\nADAR-like\nU>C"]

    # ADAR preference, point plot
    data_filter_synonymous["passage"] = data_filter_synonymous["passage"].astype(str)
    catplot_adar = sns.catplot(x="passage", y="frac_and_weight", data=data_filter_synonymous, hue="Mutation_adar",
                               order=passage_order, palette=mutation_palette(4, adar=True), kind="point", dodge=0.5,
                               hue_order=mutation_adar_order, join=False, estimator=weighted_varaint, orient="v",
                               legend=True)
    catplot_adar.set_axis_labels("Passage", "Variant Frequency {0} CI=95%".format(plus_minus))
    catplot_adar.set(yscale='log')
    catplot_adar.set(ylim=(10 ** -6, 10 ** -2))
    plt.savefig(output_dir + "/adar_pref_mutation_point_plot_{0}.png".format(virus), dpi=300)
    plt.close()

    # ADAR preference, box plot with statistical annotations
    adar_g = sns.boxplot(x="passage", y="Frequency", data=data_filter_synonymous, hue="Mutation_adar",
                         order=passage_order, palette=mutation_palette(4, adar=True), dodge=True,
                         hue_order=mutation_adar_order)
    adar_g.set_yscale('log')
    adar_g.set_ylim(10 ** -6, 10 ** -1)
    adar_g.set(xlabel="Passage", ylabel="Variant Frequency")
    annot = Annotator(adar_g, pairs_adar, x="passage", y="Frequency", hue="Mutation_adar",
                      data=data_filter_synonymous, hue_order=mutation_adar_order, order=passage_order)
    annot.configure(test='t-test_welch', text_format='star', loc='outside', verbose=2,
                    comparisons_correction="Bonferroni")
    annot.apply_test()
    file_path = output_dir + "/sts_adar.csv"
    with open(file_path, "w") as o:
        with contextlib.redirect_stdout(o):
            adar_g, test_results = annot.annotate()
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(output_dir + "/adar_pref_mutation_box_plot_{0}.png".format(virus), dpi=300)
    plt.close()
# Example #4
# 0
def _transition_pairs(passages):
    """Pair A>G with each other transition (G>A, U>C, C>U) inside every passage."""
    return [((passage, "A>G"), (passage, other))
            for other in ("G>A", "U>C", "C>U")
            for passage in passages]


def _adar_pairs(ag_passages, uc_passages):
    """Pair the High and Low ADAR-like groups per passage: A>G pairs, then U>C."""
    ag_pairs = [((passage, "High\nADAR-like\nA>G"),
                 (passage, "Low\nADAR-like\nA>G"))
                for passage in ag_passages]
    uc_pairs = [((passage, "High\nADAR-like\nU>C"),
                 (passage, "Low\nADAR-like\nU>C"))
                for passage in uc_passages]
    return ag_pairs + uc_pairs


def _label_by_adar_preference(frame, mutation, preference_column):
    """Return the "Mutation" values with `mutation` rows relabelled by preference.

    Rows whose "Mutation" equals `mutation` and whose `preference_column` is
    High, Intermediate or Low get the label made of that level, "ADAR-like"
    and the mutation, joined by newlines. All other rows keep their current
    "Mutation" value.
    """
    preference = frame[preference_column]
    relabel = ((frame["Mutation"] == mutation)
               & preference.isin(["High", "Intermediate", "Low"]))
    labels = preference.astype(str) + "\nADAR-like\n" + mutation
    return np.where(relabel, labels, frame["Mutation"])


def _save_point_plot(data, hue, hue_order, palette, order, figure_path):
    """Draw the log-scaled `frac_and_weight` point plot and save it."""
    plus_minus = u"\u00B1"
    grid = sns.catplot(x="passage",
                       y="frac_and_weight",
                       data=data,
                       hue=hue,
                       order=order,
                       palette=palette,
                       kind="point",
                       dodge=0.5,
                       hue_order=hue_order,
                       join=False,
                       estimator=weighted_varaint,
                       orient="v",
                       legend=True)
    grid.set_axis_labels("Passage",
                         "Variant Frequency {} CI=95%".format(plus_minus))
    grid.set(yscale='log', ylim=(10 ** -6, 10 ** -2))
    plt.savefig(figure_path, dpi=300)
    plt.close()


def _save_annotated_box_plot(data, hue, hue_order, palette, order, pairs,
                             y_max, legend_anchor, stats_path, figure_path):
    """Draw the Frequency box plot, annotate `pairs`, and save figure and stats.

    The text printed by ``Annotator.annotate()`` is captured into
    `stats_path`; the figure is written to `figure_path`.
    """
    axes = sns.boxplot(x="passage",
                       y="Frequency",
                       data=data,
                       hue=hue,
                       order=order,
                       palette=palette,
                       dodge=True,
                       hue_order=hue_order)
    axes.set_yscale('log')
    axes.set_ylim(10 ** -6, y_max)
    axes.set(xlabel="Passage", ylabel="Variant Frequency")

    # The annotator gets the same order/hue_order as the box plot so that
    # every bracket is placed over the boxes it refers to.
    annot = Annotator(axes,
                      pairs,
                      x="passage",
                      y="Frequency",
                      hue=hue,
                      data=data,
                      order=order,
                      hue_order=hue_order)
    annot.configure(test='t-test_welch',
                    text_format='star',
                    loc='outside',
                    verbose=2,
                    comparisons_correction="Bonferroni")
    annot.apply_test()
    with open(stats_path, "w") as stats_file, \
            contextlib.redirect_stdout(stats_file):
        annot.annotate()
    plt.legend(bbox_to_anchor=legend_anchor, loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(figure_path, dpi=300)
    plt.close()


def main():
    """Plot transition and ADAR-preference variant frequencies per replica.

    Reads ``data_filter.pkl`` from the hard-coded input directory, creates a
    dated output directory next to it and, for each of replicas 1-3, writes:

    * a point plot and an annotated box plot of the four transition mutations,
    * a point plot and an annotated box plot of the High/Low ADAR-like
      A>G and U>C groups among synonymous variants,
    * the text printed by the statistical annotation of each box plot
      (``sts<replica>.csv`` and ``sts_adar_<replica>.csv``).
    """
    # Alternative (network share) input location:
    # input_dir = "/Volumes/STERNADILABHOME$/volume3/okushnir/AccuNGS/20201008RV-202329127/merged/passages/"
    input_dir = "/Users/odedkushnir/PhD_Projects/After_review/AccuNGS/RV/passages/"
    prefix = "inosine_predict_context"
    date = datetime.today().strftime("%Y%m%d")
    output_dir = input_dir + "{0}_{1}".format(date, prefix)
    try:
        os.mkdir(output_dir)
    except OSError:
        print("Creation of the directory %s failed" % output_dir)
    else:
        print("Successfully created the directory %s " % output_dir)

    data_filter = pd.read_pickle(input_dir + prefix + "/data_filter.pkl")
    data_filter["passage"] = data_filter["passage"].astype(int)
    data_filter["no_variants"] = np.where(data_filter["Prob"] < 0.95, 0,
                                          data_filter["no_variants"])
    # Keep only rows with more than 10000 reads. The previous code assigned
    # the filtered frame to the "Read_count" column, which does not filter
    # any rows and raises on current pandas.
    data_filter = data_filter[data_filter["Read_count"] > 10000].copy()

    transition_order = ["A>G", "U>C", "G>A", "C>U"]
    mutation_adar_order = [
        "High\nADAR-like\nA>G", "Low\nADAR-like\nA>G",
        "High\nADAR-like\nU>C", "Low\nADAR-like\nU>C"
    ]

    for replica in [1, 2, 3]:
        in_replica = data_filter["replica"] == replica
        if replica == 2:
            # Passage-0 rows of every replica are also counted as replica 2.
            in_replica = in_replica | (data_filter["passage"] == 0)
        # Explicit copy: columns are rewritten below.
        data_filter_replica = data_filter[in_replica].copy()
        if replica == 2:
            data_filter_replica["replica"] = replica
        data_filter_replica["passage"] = (
            "p" + data_filter_replica["passage"].astype(str))
        data_filter_replica["passage"] = np.where(
            data_filter_replica["passage"] == "p0", "RNA\nControl",
            data_filter_replica["passage"])

        if replica == 1:
            passage_order = ["RNA\nControl", "p2", "p5", "p8", "p12"]
            # NOTE(review): the hand-written replica-1 list had no RNA control
            # U>C comparison; that omission is preserved - confirm it is
            # intentional.
            uc_passages = passage_order[1:]
        else:
            passage_order = ["RNA\nControl", "p2", "p5", "p8", "p10", "p12"]
            uc_passages = passage_order
        pairs = _transition_pairs(passage_order)
        pairs_adar = _adar_pairs(passage_order, uc_passages)

        _save_point_plot(
            data_filter_replica, "Mutation", transition_order,
            mutation_palette(4), passage_order,
            output_dir +
            "/Transition_Mutations_point_plot_RVB14_replica%s" % str(replica))
        _save_annotated_box_plot(
            data_filter_replica, "Mutation", transition_order,
            mutation_palette(4), passage_order, pairs,
            y_max=10 ** -2,
            legend_anchor=(1.05, 0.5),
            stats_path=output_dir + "/sts{0}.csv".format(replica),
            figure_path=output_dir +
            "/Transition_Mutations_box_stat_plot_RVB14_replica{0}".format(
                replica))

        # ADAR preferences (synonymous variants only).
        data_filter_replica_synonymous = data_filter_replica.loc[
            data_filter_replica.Type == "Synonymous"].copy()
        data_filter_replica_synonymous["Mutation"] = _label_by_adar_preference(
            data_filter_replica_synonymous, "A>G", "5`_ADAR_Preference")
        # "Mutation" already carries the A>G labels at this point, so the
        # fallback values make "Mutation_adar" hold both A>G and U>C groups.
        data_filter_replica_synonymous["Mutation_adar"] = (
            _label_by_adar_preference(data_filter_replica_synonymous, "U>C",
                                      "3`_ADAR_Preference"))

        _save_point_plot(
            data_filter_replica_synonymous, "Mutation_adar",
            mutation_adar_order, mutation_palette(4, adar=True),
            passage_order,
            output_dir +
            "/adar_pref_mutation_point_plot_RVB14_replica{0}.png".format(
                replica))
        _save_annotated_box_plot(
            data_filter_replica_synonymous, "Mutation_adar",
            mutation_adar_order, mutation_palette(4, adar=True),
            passage_order, pairs_adar,
            y_max=10 ** -1,
            legend_anchor=(1.05, 1),
            stats_path=output_dir + "/sts_adar_{0}.csv".format(replica),
            figure_path=output_dir +
            "/adar_pref_mutation_box_stat_plot_RVB14_replica{0}".format(
                replica))