Example #1
def main(config=None):
    if config is None:
        args = parse_args()
        # load config file
        assert os.path.exists(
            args.config), "Config file does not exist: {}".format(args.config)
        config = load_json(args.config)

    args = create_dot_dict(config)
    # load model files
    models = []
    kmer_lists = []
    assignment_data = []
    strands = []
    max_plots = 0
    # create models and grab kmer lists
    for model in args.models:
        models.append(
            HmmModel(ont_model_file=model.ont_model,
                     hdp_model_file=model.hdp_model,
                     nanopolish_model_file=model.nanopolish_model,
                     rna=model.rna,
                     name=model.name))
        model_kmer_list = model.kmers
        n_kmers_to_plot = len(model_kmer_list)
        kmer_lists.append(model_kmer_list)
        max_plots = n_kmers_to_plot if n_kmers_to_plot > max_plots else max_plots

        if model.builtAlignment_tsv is not None:
            assert os.path.exists(model.builtAlignment_tsv), \
                "builtAlignment_tsv does not exist: {}".format(model.builtAlignment_tsv)
            # try assignment format first, fall back to alignment format
            try:
                assignment_data.append(
                    parse_assignment_file(model.builtAlignment_tsv))
            except ValueError:
                assignment_data.append(
                    parse_alignment_file(model.builtAlignment_tsv))
        else:
            assignment_data.append(None)
        strands.append(model.strand)

    mmh = MultipleModelHandler(models,
                               strands=strands,
                               assignment_data=assignment_data,
                               savefig_dir=args.save_fig_dir)
    if args.summary_distance:
        mmh.plot_all_model_comparisons()
    # Start plotting
    for kmer_list in zip_longest(*kmer_lists):
        mmh.plot_kmer_distribution(kmer_list)

    if args.save_fig_dir:
        save_json(
            args,
            os.path.join(args.save_fig_dir,
                         "compare_trained_models_config.json"))
Example #2
    def setUpClass(cls):
        super(SignalAlignmentTest, cls).setUpClass()
        cls.HOME = '/'.join(os.path.abspath(__file__).split("/")[:-4])
        cls.reference = os.path.join(
            cls.HOME, "tests/test_sequences/pUC19_SspI_Zymo.fa")
        cls.ecoli_reference = os.path.join(
            cls.HOME, "tests/test_sequences/E.coli_K12.fasta")

        cls.fast5_dir = os.path.join(
            cls.HOME, "tests/minion_test_reads/canonical_ecoli_R9")
        cls.files = [
            "miten_PC_20160820_FNFAD20259_MN17223_mux_scan_AMS_158_R9_WGA_Ecoli_08_20_16_83098_ch138_read23_strand.fast5",
            "miten_PC_20160820_FNFAD20259_MN17223_sequencing_run_AMS_158_R9_WGA_Ecoli_08_20_16_43623_ch101_read456_strand.fast5",
            "miten_PC_20160820_FNFAD20259_MN17223_sequencing_run_AMS_158_R9_WGA_Ecoli_08_20_16_43623_ch101_read544_strand1.fast5",
            "miten_PC_20160820_FNFAD20259_MN17223_sequencing_run_AMS_158_R9_WGA_Ecoli_08_20_16_43623_ch103_read333_strand1.fast5"
        ]
        cls.fast5_paths = list_dir(cls.fast5_dir, ext="fast5")
        cls.fast5_bam = os.path.join(
            cls.HOME,
            "tests/minion_test_reads/canonical_ecoli_R9/canonical_ecoli.bam")
        cls.fast5_readdb = os.path.join(
            cls.HOME,
            "tests/minion_test_reads/canonical_ecoli_R9/canonical_ecoli.readdb"
        )

        cls.template_hmm = os.path.join(
            cls.HOME, "models/testModelR9_acgt_template.model")
        cls.path_to_bin = os.path.join(cls.HOME, 'bin')
        cls.tmp_directory = tempfile.mkdtemp()
        cls.test_dir = os.path.join(cls.tmp_directory, "test")

        dna_dir = os.path.join(cls.HOME, "tests/minion_test_reads/1D/")
        # copy DNA test reads to the tmp directory
        shutil.copytree(dna_dir, cls.test_dir)
        cls.readdb = os.path.join(
            cls.HOME, "tests/minion_test_reads/oneD.fastq.index.readdb")
        cls.bam = os.path.join(cls.HOME, "tests/minion_test_reads/oneD.bam")

        cls.rna_bam = os.path.join(
            cls.HOME, "tests/minion_test_reads/RNA_edge_cases/rna_reads.bam")
        cls.rna_readdb = os.path.join(
            cls.HOME,
            "tests/minion_test_reads/RNA_edge_cases/rna_reads.readdb")
        cls.test_dir_rna = os.path.join(cls.tmp_directory, "test_rna")
        cls.rna_reference = os.path.join(
            cls.HOME, "tests/test_sequences/fake_rna_ref.fa")

        rna_dir = os.path.join(cls.HOME,
                               "tests/minion_test_reads/RNA_edge_cases/")
        # copy RNA test reads to the tmp directory
        shutil.copytree(rna_dir, cls.test_dir_rna)
        # used to test runSignalAlign with config file
        cls.config_file = os.path.join(cls.HOME,
                                       "tests/runSignalAlign-config.json")
        cls.default_args = create_dot_dict(load_json(cls.config_file))
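setUpClass above copies test reads into a fresh temporary directory. The original class is not shown in full, so the teardown below is only a sketch of the cleanup one would expect to pair with it, assuming no such method already exists.

    # Sketch only: cleanup consistent with setUpClass above (assumed, not taken from the original class).
    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmp_directory, ignore_errors=True)
        super(SignalAlignmentTest, cls).tearDownClass()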
Example #3
    def test_create_dot_dict(self):
        with captured_output() as (_, _):
            dict1 = {"a": 1, "b": 2, 3: "c", "d": {"e": 4}}
            self.assertRaises(AssertionError, create_dot_dict, dict1)
            self.assertRaises(TypeError, create_dot_dict, [1, 2, 3])

            dict1 = {"a": 1, "b": 2, "d": {"e": 4}, "f": [{"g": 5, "h": 6}]}
            dict2 = create_dot_dict(dict1)
            self.assertEqual(dict2.a, 1)
            self.assertEqual(dict2.b, 2)
            self.assertEqual(dict2.d, {"e": 4})
            self.assertEqual(dict2.d.e, 4)
            self.assertEqual(dict2.f[0].g, 5)
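The assertions above pin down the expected behaviour of create_dot_dict: string keys only (AssertionError otherwise), TypeError on non-dict input, and recursive conversion of nested dicts and lists of dicts into attribute-accessible objects. A minimal sketch consistent with those assertions, not the library's actual implementation, is:

class DotDictSketch(dict):
    """Dict with attribute-style access; a stand-in for the real dot-dict type."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    __setattr__ = dict.__setitem__


def create_dot_dict_sketch(data):
    """Recursively convert a plain dict into a DotDictSketch."""
    if not isinstance(data, dict):
        raise TypeError("Expected dict, got {}".format(type(data)))
    assert all(isinstance(key, str) for key in data), "All keys must be strings"
    converted = DotDictSketch()
    for key, value in data.items():
        if isinstance(value, dict):
            converted[key] = create_dot_dict_sketch(value)
        elif isinstance(value, list):
            converted[key] = [create_dot_dict_sketch(v) if isinstance(v, dict) else v
                              for v in value]
        else:
            converted[key] = value
    return converted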
Example #4
def main(config=None):
    """Plot event to reference labelled ONT nanopore reads"""
    start = timer()
    if config is None:
        args = parse_args()
        # load config file
        assert os.path.exists(
            args.config), "Config file does not exist: {}".format(args.config)
        config = load_json(args.config)

    args = create_dot_dict(config)
    # get assignments and load model
    try:
        assignments = parse_assignment_file(args.assignments)
    except ValueError:
        assignments = parse_alignment_file(args.assignments)

    model_h = HmmModel(args.model_path, rna=args.rna)
    target_model = None
    if args.target_hmm_model is not None:
        target_model = HmmModel(args.target_hmm_model,
                                hdp_model_file=args.target_hdp_model,
                                rna=args.rna)
    # generate kmers to match
    all_kmer_pairs = set()
    for motif in args.motifs:
        all_kmer_pairs |= set(
            tuple(row) for row in get_motif_kmer_pairs(motif_pair=motif,
                                                       k=model_h.kmer_length))

    data = generate_gaussian_mixture_model_for_motifs(
        model_h,
        assignments,
        all_kmer_pairs,
        args.strand,
        args.output_dir,
        plot=args.plot,
        name="ccwgg",
        target_model=target_model,
        show=args.show)
    # data = pd.read_csv(os.path.join(args.output_dir, "t_distances.tsv"), delimiter="\t")
    # data = data.ix[0]
    # plot_mixture_model_distribution(data["kmer"], data["canonical_model_mean"], data["canonical_model_sd"],
    #                                 data["canonical_mixture_mean"],
    #                                 data["canonical_mixture_sd"], data["modified_mixture_mean"],
    #                                 data["modified_mixture_sd"],
    #                                 data["strand"], kmer_assignments=assignments, save_fig_dir=None)
    stop = timer()
    print("Running Time = {} seconds".format(stop - start), file=sys.stderr)
Example #5
def main(config=None):
    """Plot number of breaks in a read over some difference and plot compared to accuracy of the variant called read"""
    start = timer()
    if config is None:
        args = parse_args()
        assert os.path.exists(
            args.config), "Config file does not exist: {}".format(args.config)
        config = load_json(args.config)

    args = create_dot_dict(config)
    threshold = args.gap_threshold
    names = []
    all_event_lengths = []
    all_skips = []
    all_percent_correct = []
    data = []
    for experiment in args.plot:
        names.append(experiment.name)
        for sample in experiment.samples:
            data = multiprocess_get_gaps_from_reads(
                sample.embedded_fast5_dir,
                experiment.sa_number,
                label=sample.label,
                gap_threshold=threshold,
                worker_count=7,
                alignment_threshold=args.threshold,
                debug=False)
        all_skips.append([x[0] for x in data])
        all_event_lengths.append([x[1] for x in data])
        all_percent_correct.append([x[2] for x in data])
    plot_n_breaks_vs_n_events_scatter_plot(all_event_lengths,
                                           all_skips,
                                           names,
                                           all_percent_correct,
                                           save_fig_path=os.path.join(
                                               args.save_fig_dir,
                                               "n_breaks_vs_n_events.png"))
    plot_n_breaks_ratio_vs_accuracy_scatter_plot(
        all_event_lengths,
        all_skips,
        names,
        all_percent_correct,
        save_fig_path=os.path.join(args.save_fig_dir,
                                   "n_breaks_ratio_vs_accuracy.png"))
    stop = timer()
    print("Running Time = {} seconds".format(stop - start), file=sys.stderr)
Example #6
def main(args):
    # parse args
    start = timer()

    args = parse_args()
    if args.command == "run":
        if not os.path.exists(args.config):
            print("{config} not found".format(config=args.config))
            exit(1)
        # run training
        config_args = create_dot_dict(load_json(args.config))

        temp_folder = FolderHandler()
        temp_dir_path = temp_folder.open_folder(
            os.path.join(os.path.abspath(config_args.output_dir),
                         "tempFiles_alignment"))
        temp_dir_path = resolvePath(temp_dir_path)
        print(config_args.output_dir)
        print(temp_dir_path)

        sa_args = [
            merge_dicts([
                s, {
                    "quality_threshold": config_args.filter_reads,
                    "workers": config_args.job_count
                }
            ]) for s in config_args.samples
        ]

        samples = [
            SignalAlignSample(working_folder=temp_folder, **s) for s in sa_args
        ]
        copyfile(args.config,
                 os.path.join(temp_dir_path, os.path.basename(args.config)))

        state_machine_type = "threeState"
        if config_args.template_hdp_model_path is not None:
            state_machine_type = "threeStateHdp"

        alignment_args = create_signalAlignment_args(
            destination=temp_dir_path,
            stateMachineType=state_machine_type,
            in_templateHmm=resolvePath(config_args.template_hmm_model),
            in_complementHmm=resolvePath(config_args.complement_hmm_model),
            in_templateHdp=resolvePath(config_args.template_hdp_model),
            in_complementHdp=resolvePath(config_args.complement_hdp_model),
            diagonal_expansion=config_args.diagonal_expansion,
            constraint_trim=config_args.constraint_trim,
            traceBackDiagonals=config_args.traceBackDiagonals,
            twoD_chemistry=config_args.two_d,
            get_expectations=False,
            path_to_bin=resolvePath(config_args.path_to_bin),
            check_for_temp_file_existance=True,
            threshold=config_args.signal_alignment_args.threshold,
            track_memory_usage=config_args.signal_alignment_args.track_memory_usage,
            embed=config_args.signal_alignment_args.embed,
            event_table=config_args.signal_alignment_args.event_table,
            output_format=config_args.signal_alignment_args.output_format,
            filter_reads=config_args.filter_reads,
            delete_tmp=config_args.signal_alignment_args.delete_tmp)

        multithread_signal_alignment_samples(samples,
                                             alignment_args,
                                             config_args.job_count,
                                             trim=None,
                                             debug=config_args.debug)

        print("\n#  signalAlign - finished alignments\n", file=sys.stderr)
        print("\n#  signalAlign - finished alignments\n", file=sys.stdout)
        stop = timer()
    else:
        command_line = " ".join(sys.argv[:])
        print(os.getcwd())

        print("Command Line: {cmdLine}\n".format(cmdLine=command_line),
              file=sys.stderr)
        # get absolute paths to inputs
        args.files_dir = resolvePath(args.files_dir)
        args.forward_reference = resolvePath(args.forward_ref)
        args.backward_reference = resolvePath(args.backward_ref)
        args.out = resolvePath(args.out)
        args.bwa_reference = resolvePath(args.bwa_reference)
        args.in_T_Hmm = resolvePath(args.in_T_Hmm)
        args.in_C_Hmm = resolvePath(args.in_C_Hmm)
        args.templateHDP = resolvePath(args.templateHDP)
        args.complementHDP = resolvePath(args.complementHDP)
        args.fofn = resolvePath(args.fofn)
        args.target_regions = resolvePath(args.target_regions)
        args.ambiguity_positions = resolvePath(args.ambiguity_positions)
        args.alignment_file = resolvePath(args.alignment_file)
        start_message = """
    #   Starting Signal Align
    #   Aligning files from: {fileDir}
    #   Aligning to reference: {reference}
    #   Aligning maximum of {nbFiles} files
    #   Using model: {model}
    #   Using banding: True
    #   Aligning to regions in: {regions}
    #   Non-default template HMM: {inThmm}
    #   Non-default complement HMM: {inChmm}
    #   Template HDP: {tHdp}
    #   Complement HDP: {cHdp}
        """.format(fileDir=args.files_dir,
                   reference=args.bwa_reference,
                   nbFiles=args.nb_files,
                   inThmm=args.in_T_Hmm,
                   inChmm=args.in_C_Hmm,
                   model=args.stateMachineType,
                   regions=args.target_regions,
                   tHdp=args.templateHDP,
                   cHdp=args.complementHDP)

        print(start_message, file=sys.stdout)

        if args.files_dir is None and args.fofn is None:
            print("Need to provide directory with .fast5 files of fofn",
                  file=sys.stderr)
            sys.exit(1)

        if not os.path.isfile(args.bwa_reference):
            print("Did not find valid reference file, looked for it {here}".
                  format(here=args.bwa_reference),
                  file=sys.stderr)
            sys.exit(1)

        # make directory to put temporary files
        if not os.path.isdir(args.out):
            print("Creating output directory: {}".format(args.out),
                  file=sys.stdout)
            os.mkdir(args.out)
        temp_folder = FolderHandler()
        temp_dir_path = temp_folder.open_folder(
            os.path.join(os.path.abspath(args.out), "tempFiles_alignment"))
        temp_dir_path = resolvePath(temp_dir_path)
        print(args.out)
        print(temp_dir_path)

        # generate reference sequence if not specified
        if not args.forward_reference or not args.backward_reference:
            args.forward_reference, args.backward_reference = processReferenceFasta(
                fasta=args.bwa_reference,
                work_folder=temp_folder,
                positions_file=args.ambiguity_positions,
                name="")

        # list of read files
        if args.fofn is not None:
            fast5s = [x for x in parseFofn(args.fofn) if x.endswith(".fast5")]
        else:
            fast5s = [
                "/".join([args.files_dir, x])
                for x in os.listdir(args.files_dir) if x.endswith(".fast5")
            ]

        nb_files = args.nb_files
        if nb_files < len(fast5s):
            shuffle(fast5s)
            fast5s = fast5s[:nb_files]

        # build alignment_args
        alignment_args = {
            "destination": temp_dir_path,
            "stateMachineType": args.stateMachineType,
            "bwa_reference": args.bwa_reference,
            "in_templateHmm": args.in_T_Hmm,
            "in_complementHmm": args.in_C_Hmm,
            "in_templateHdp": args.templateHDP,
            "in_complementHdp": args.complementHDP,
            "output_format": args.outFmt,
            "threshold": args.threshold,
            "diagonal_expansion": args.diag_expansion,
            "constraint_trim": args.constraint_trim,
            "degenerate": getDegenerateEnum(args.degenerate),
            "twoD_chemistry": args.twoD,
            "target_regions": args.target_regions,
            "embed": args.embed,
            "event_table": args.event_table,
            "backward_reference": args.backward_reference,
            "forward_reference": args.forward_reference,
            "alignment_file": args.alignment_file,
            "check_for_temp_file_existance": True,
            "track_memory_usage": False,
            "get_expectations": False,
            "perform_kmer_event_alignment": args.perform_kmer_event_alignment,
            "enforce_supported_versions": args.enforce_supported_versions,
            "filter_reads": 7 if args.filter_reads else None,
            "path_to_bin": args.path_to_bin,
            "delete_tmp": args.delete_tmp
        }
        filter_read_generator = None
        if args.filter_reads is not None and args.alignment_file and args.readdb and args.files_dir:
            print("[runSignalAlign]:NOTICE: Filtering out low quality reads",
                  file=sys.stdout)

            filter_read_generator = filter_reads_to_string_wrapper(
                filter_reads(args.alignment_file,
                             args.readdb, [args.files_dir],
                             quality_threshold=7,
                             recursive=args.recursive))

        print("[runSignalAlign]:NOTICE: Got {} files to align".format(
            len(fast5s)),
              file=sys.stdout)
        # setup workers for multiprocessing
        multithread_signal_alignment(
            alignment_args,
            fast5s,
            args.nb_jobs,
            debug=args.DEBUG,
            filter_reads_to_string_wrapper=filter_read_generator)
        stop = timer()

        print("\n#  signalAlign - finished alignments\n", file=sys.stderr)
        print("\n#  signalAlign - finished alignments\n", file=sys.stdout)

    print("[signalAlign] Complete")
    print("Running Time = {} seconds".format(stop - start))
Example #7
def plot_roc_from_config(config):
    """Plotting function to handle logic of the config file. Mainly created to test function"""
    config = create_dot_dict(config)

    variants = config.variants
    samples = config.samples
    if isinstance(config.threshold, float):
        threshold = config.threshold
    else:
        threshold = 0.500000001

    if isinstance(config.jobs, int):
        n_processes = config.jobs
    else:
        n_processes = 2

    save_fig_dir = config.save_fig_dir

    assert len(samples) > 0, "Must include samples in order to do comparison"
    aor_handles = []
    gwa_labels_list = []
    per_site_label_list = []
    plot_per_read = False
    plot_genome_position_aggregate = False
    plot_per_call = False

    # process samples
    for sample in samples:
        tsvs = sample.full_tsvs
        positions = sample.positions_file
        label = sample.label
        aor_h = AggregateOverReadsFull(tsvs, variants, verbose=True, processes=n_processes)
        aor_h.marginalize_over_all_reads()
        aor_handles.append(aor_h)
        assert positions or label, "Must provide either a label: {} or a positions file: {}".format(label,
                                                                                                    positions)
        # use character as label if given
        if label:
            plot_genome_position_aggregate = True
            plot_per_call = True
            plot_per_read = True
            aor_h.aggregate_position_probs = aor_h.generate_labels2(predicted_data=aor_h.aggregate_position_probs,
                                                                    true_char=label)
            aor_h.per_read_data = aor_h.generate_labels2(predicted_data=aor_h.per_read_data,
                                                         true_char=label)
            aor_h.per_position_data = aor_h.generate_labels2(predicted_data=aor_h.per_position_data,
                                                             true_char=label)

        # if positions file is given, check accuracy from that
        elif positions:
            plot_genome_position_aggregate = True
            plot_per_call = True

            genome_position_labels = CustomAmbiguityPositions.parseAmbiguityFile(positions)
            aor_h.aggregate_position_probs = aor_h.generate_labels(labelled_positions=genome_position_labels,
                                                                   predicted_data=aor_h.aggregate_position_probs)
            aor_h.per_position_data = aor_h.generate_labels(labelled_positions=genome_position_labels,
                                                            predicted_data=aor_h.per_position_data)

    # plot per read ROC curve
    if plot_per_read:
        all_per_read_labels = pd.concat([x.per_read_data for x in aor_handles], ignore_index=True)
        data_type_name = "per_read"
        plot_all_roc_curves(all_per_read_labels, variants, save_fig_dir, data_type_name, threshold=threshold)

    # plot per call ROC curve
    if plot_per_call:
        all_site_labels = pd.concat([x.per_position_data for x in aor_handles], ignore_index=True)
        data_type_name = "per_site_per_read"
        plot_all_roc_curves(all_site_labels, variants, save_fig_dir, data_type_name, threshold=threshold)

    # plot genome position calls
    if plot_genome_position_aggregate:
        all_genome_positions_labels = pd.concat([x.aggregate_position_probs for x in aor_handles], ignore_index=True)
        data_type_name = "per_genomic_site"
        plot_all_roc_curves(all_genome_positions_labels, variants, save_fig_dir, data_type_name, label_key="contig",
                            threshold=threshold)

    return 0
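A hypothetical input for plot_roc_from_config(); the function calls create_dot_dict itself, so a plain dict is enough. Key names follow the dot accesses above and are illustrative only.

# Hypothetical config sketch; keys inferred from the function above.
roc_config_sketch = {
    "variants": "CE",
    "jobs": 4,
    "threshold": 0.5,
    "save_fig_dir": "plots/",
    "samples": [
        {"label": "C", "positions_file": None, "full_tsvs": ["sample1_full.tsv"]}
    ],
}
# plot_roc_from_config(roc_config_sketch)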
Example #8
def main(config=None):
    """Plot event to reference labelled ONT nanopore reads"""
    start = timer()
    if config is None:
        args = parse_args()
        # load config file
        assert os.path.exists(args.config), "Config file does not exist: {}".format(args.config)
        config = load_json(args.config)

    args = create_dot_dict(config)
    threshold = args.threshold
    all_data = []
    names = []
    for experiment in args.plot:
        names.append(experiment.name)
        experiment_data = []
        for sample in experiment.samples:
            tsvs = None
            f5s = None
            if sample.variant_tsvs is not None:
                tsvs = list_dir(sample.variant_tsvs, ext="vc.tsv")
            if sample.embedded_fast5_dir is not None:
                f5s = list_dir(sample.embedded_fast5_dir, ext="fast5")

            data = multiprocess_get_distance_from_guide(
                sample.embedded_fast5_dir, experiment.sa_number, threshold,
                sample.label, worker_count=7, debug=False)

            experiment_data.extend([x for x in data if x is not None])
        all_data.append(pd.concat(experiment_data))
    true_deltas = pd.concat([data[data["true_false"]]["guide_delta"] for data in all_data])
    true_starts = pd.concat([data[data["true_false"]]["raw_start"] for data in all_data])

    false_deltas = pd.concat([data[[not x for x in data["true_false"]]]["guide_delta"] for data in all_data])
    false_starts = pd.concat([data[[not x for x in data["true_false"]]]["raw_start"] for data in all_data])

    plot_deviation_vs_time_from_start([true_deltas, false_deltas], [true_starts, false_starts], ["True", "False"],
                                      os.path.join(args.save_fig_dir, "raw_start_vs_alignment_deviation_accuracy.png"))

    plot_deviation_vs_time_from_start([x["guide_delta"] for x in all_data], [x["raw_start"] for x in all_data], names,
                                      os.path.join(args.save_fig_dir, "raw_start_vs_alignment_deviation.png"))

    new_names = []
    new_data = []
    for name, data in zip(names, all_data):
        new_data.append(data[data["true_false"]]["guide_delta"])
        new_names.append(name+"_correct")
        new_data.append(data[[not x for x in data["true_false"]]]["guide_delta"])
        new_names.append(name+"_wrong")

    plot_alignment_deviation(new_data, new_names, bins=np.arange(-10000, 10000, 100),
                             save_fig_path=os.path.join(args.save_fig_dir, "alignment_deviation_hist.png"))

    plot_violin_classication_alignment_deviation(new_data, new_names,
                                                 save_fig_path=os.path.join(args.save_fig_dir,
                                                                            "alignment_deviation_violin.png"))

    plot_classification_accuracy_vs_deviation(all_data, names,
                                              save_fig_path=os.path.join(args.save_fig_dir,
                                                                         "classification_accuracy_vs_deviation.png"))

    stop = timer()
    print("Running Time = {} seconds".format(stop - start), file=sys.stderr)