Пример #1
0
def run_downloaddata(args):
    """ Download data through the s3 client, configured from the commandline or a yaml file.

    The configuration dict holds the keys "endpoint", "username", "accesskey" and "buckets".
    It is first assembled from the individual commandline options; if args.yaml is given,
    every one of these keys is replaced by the value found in the yaml config (a key
    missing from the yaml config raises KeyError).

    Parameters:
        args: namespace providing verbosity, endpoint, username, key, bucket, patterns,
              yaml and force.
    """

    logger = TobiasLogger("DownloadData", args.verbosity)
    logger.begin()

    #Baseline configuration taken directly from the commandline options
    config = dict(
        endpoint=args.endpoint,
        username=args.username,
        accesskey=args.key,
        buckets={args.bucket: args.patterns},
    )

    #A yaml file takes precedence over the commandline for all config keys
    if args.yaml is not None:
        yaml_config = read_config_yaml(args.yaml)
        config = {name: yaml_config[name] for name in config}

    logger.debug("Configuration dict is: {0}".format(config))

    #Hand over to the s3 client for the actual download
    s3_client(config, logger, args.force)

    logger.end()
Пример #2
0
def count_reads(regions_list, params):
    """ Count reads from bam within regions (counts position of cutsite to prevent double-counting).

    Parameters:
        regions_list: sequence of regions; each region is iterable (used for the log label)
                      and provides .start and .end.
        params: object providing bam (path to the .bam file), read_shift, verbosity and log_q.

    Returns:
        Number of reads whose cutsite satisfies region.start < cutsite <= region.end,
        summed over all regions. An empty regions_list yields 0.
    """

    read_shift = params.read_shift
    logger = TobiasLogger("", params.verbosity, params.log_q)  #sending all logger calls to log_q

    #Label used in the start/finish messages; an empty chunk has no first/last region to show
    if len(regions_list) > 0:
        first_region = "_".join([str(element) for element in regions_list[0]])
        last_region = "_".join([str(element) for element in regions_list[-1]])
        chunk_label = "({0} -> {1})".format(first_region, last_region)
    else:
        chunk_label = "(empty)"

    read_count = 0
    bam_obj = pysam.AlignmentFile(params.bam, "rb")
    try:
        logger.spam("Started counting region_chunk {0}".format(chunk_label))

        #Count per region
        for region in regions_list:
            read_lst = ReadList().from_bam(bam_obj, region)
            logger.spam("- {0} ({1} reads)".format(region, len(read_lst)))
            for read in read_lst:
                read.get_cutsite(read_shift)
                if region.start < read.cutsite <= region.end:  #only reads within borders
                    read_count += 1

        logger.spam("Finished counting region_chunk {0}".format(chunk_label))
    finally:
        #Release the file handle even if reading/counting fails
        bam_obj.close()

    return read_count
Пример #3
0
    def get_signal(self, pybw, numpy_bool=True, logger=TobiasLogger()):
        """ Fetch the bigwig signal covering this region.

        Values for self.chrom:self.start-self.end are read from the open bigwig
        object, passed through np.nan_to_num, and reversed when the region lies
        on the "-" strand.

        Parameters:
            pybw: open bigwig object providing .values(chrom, start, end, ...).
            numpy_bool: forwarded as numpy= to pybw.values when pyBigWig reports numpy support.
            logger: logger used to report read errors.

        Returns:
            The signal values in the orientation of the region.

        Raises:
            Any exception raised while reading is logged, its traceback printed, and re-raised.
        """

        try:
            #pyBigWig.numpy tells whether pyBigWig was compiled with numpy support
            if pyBigWig.numpy == 1:
                raw_values = pybw.values(self.chrom, self.start, self.end, numpy=numpy_bool)
            else:
                #plain list of values; convert to numpy array manually
                raw_values = np.array(pybw.values(self.chrom, self.start, self.end))

            cleaned = np.nan_to_num(raw_values)  #nan to 0

            #Signal follows the direction of the region
            return cleaned[::-1] if self.strand == "-" else cleaned

        except Exception as e:
            message = "Error reading region: {0} from pybigwig object. Exception is: {1}".format(self.tup(), e)
            logger.error(message)
            traceback.print_tb(e.__traceback__)
            raise e
Пример #4
0
    def from_bed(self, bedfile_f, logger=TobiasLogger()):
        """ Initialize Object from bedfile.

        Lines starting with "#" are skipped. Every other line must consist of
        tab-separated columns "chrom<TAB>start<TAB>end" followed by at least one more
        character, with start < end; otherwise the error is logged and the program
        exits with status 1.

        Parameters:
            bedfile_f: path to the bedfile.
            logger: logger used to report malformed lines.

        Returns:
            A new RegionList with one OneRegion per data line, in file order.
        """

        #Read all lines; the context manager makes sure the file handle is closed again
        with open(bedfile_f) as bedfile:
            bedlines = bedfile.readlines()

        #Collect only actual regions, so that comment lines do not leave None placeholders in the list
        regions = []
        for i, line in enumerate(bedlines):

            if line.startswith("#"):  #comment lines are excluded
                continue

            #Test line format
            if re.match(r"[^\s]+\t\d+\t\d+.", line) is None:
                logger.error(
                    "Line {0} in {1} is not proper bed format:\n{2}".format(
                        i + 1, bedfile_f, line))
                sys.exit(1)

            columns = line.rstrip().split("\t")
            columns[1] = int(columns[1])  #start
            columns[2] = int(columns[2])  #end

            #Regions need a positive length
            if columns[1] >= columns[2]:
                logger.error(
                    "Line {0} in {1} is not proper bed format:\n{2}".format(
                        i + 1, bedfile_f, line))
                sys.exit(1)

            regions.append(OneRegion(columns))

        self = RegionList(regions)

        return (self)
Пример #5
0
    def check_boundary(self,
                       boundaries_dict,
                       action="cut",
                       logger=TobiasLogger()):
        """ Check if region is within chromosome boundaries.

        Parameters:
            boundaries_dict: dict mapping chromosome name to chromosome length
                             (values are converted with int() before use).
            action: what to do with a region outside of the boundaries:
                - "cut": cut region to bounds. If the chromosome is not in boundaries_dict, "cut" falls back on "remove"
                - "remove": remove region outside bounds (returns None)
                - "exit": exit the program with error message through logger
            logger: logger used for the "exit" error messages.

        Returns:
            The (possibly cut) region itself, or None if the region was removed.
        """

        #Chromosome is unknown; boundaries cannot be established
        if self.chrom not in boundaries_dict:
            if action == "exit":
                logger.error(
                    "Chromosome for region \"{0}\" is not found in list of available chromosomes ({1})"
                    .format(self, list(boundaries_dict.keys())))
                sys.exit(1)

            return (None)  #cannot cut to bounds when boundaries are not known; remove

        #Convert once, so that the comparison and the cut use the same integer boundary
        chrom_end = int(boundaries_dict[self.chrom])
        outside = self.start < 0 or self.end > chrom_end

        #Perform action if region is outside of bounds
        if outside:
            if action == "cut":
                self.start = max(0, self.start)
                self.end = min(chrom_end, self.end)

                #Update positions in list
                self[1] = self.start
                self[2] = self.end

                #If the region has been cut to be 0 or less length; remove
                if self.get_length() <= 0:
                    self = None

            elif action == "remove":
                self = None

            elif action == "exit":
                logger.error(
                    "Region \"{0}\" is outside of the chromosome boundaries ({1}: {2})"
                    .format(self, self.chrom, boundaries_dict[self.chrom]))
                sys.exit(1)

        return (self)
Пример #6
0
    def from_bed(self, bedfile_f, logger=TobiasLogger()):
        """ Initialize Object from bedfile.

        Lines starting with "#" are skipped. Every other line must start with
        tab-separated "chrom<TAB>start<TAB>end" columns and have start < end;
        otherwise the error is logged and the program exits with status 1.

        Parameters:
            bedfile_f: path to the bedfile.
            logger: logger used to report malformed lines.

        Returns:
            A new RegionList with one OneRegion per data line, in file order.
        """

        #Read all lines; the context manager makes sure the file handle is closed again
        with open(bedfile_f) as bedfile:
            bedlines = bedfile.readlines()

        self = RegionList(
            [None] *
            len(bedlines))  #initialize to prevent appending for large bedfiles
        for i, line in enumerate(bedlines):

            if line.startswith("#"):  #comment lines are excluded
                continue

            #Test line format
            if re.match(r"[^\s]+\t\d+\t\d+\b.*", line) is None:
                logger.error(
                    "Line {0} in {1} is not proper bed format:\n{2}".format(
                        i + 1, bedfile_f, line))
                sys.exit(1)

            columns = line.split("\t")  #split first, so that empty columns are kept
            columns[-1] = columns[-1].rstrip(
            )  # remove line-ending from last col
            columns[1] = int(columns[1])  #start
            columns[2] = int(columns[2])  #end

            #Regions need a positive length (start equal to end is rejected as well)
            if columns[1] >= columns[2]:
                logger.error(
                    "Start position is larger than end position in line {0} in {1}:\n{2}"
                    .format(i + 1, bedfile_f, line))
                sys.exit(1)

            region = OneRegion(columns)
            self[i] = region

        #Remove any None's in list (placeholders left behind by comment lines)
        for i in range(
                len(self) - 1, -1, -1
        ):  #counting down so that index in list does not change as None's are removed
            if self[i] is None:
                del self[i]

        return (self)
Пример #7
0
def run_plotchanges(args):
    """ Plot changes in TF binding across conditions from a bindetect results table.

    The bindetect table (tab-separated, indexed by its "output_prefix" column) is read,
    optionally subset to the TFs named in the --TFS file, and written to a multi-page pdf:
    one page per measure (number of bound sites, percent of bound sites, mean score),
    first for individual TFs and then for the per-"cluster" means.

    Parameters:
        args: namespace providing verbosity, bindetect, TFS, output and conditions.
              args.conditions is filled in from the table columns if it is None.

    Exits with status 1 if --conditions is not a subset of the conditions in the table.
    """

    #------------------------------------ Get ready ------------------------------------#
    logger = TobiasLogger("PlotChanges", args.verbosity)
    logger.begin()

    check_required(args, ["bindetect"])
    check_files([args.bindetect, args.TFS], "r")
    check_files([args.output], "w")

    #------------------------------------ Read data ------------------------------------#

    logger.info("Reading data from bindetect file")

    # Read in bindetect file
    bindetect = pd.read_csv(args.bindetect, sep="\t")
    bindetect.set_index("output_prefix", inplace=True, drop=False)

    all_TFS = list(bindetect["output_prefix"])
    logger.info("{0} TFS found in bindetect file".format(len(all_TFS)))

    #Read in TF names from --TFS:
    if args.TFS is not None:

        with open(args.TFS, "r") as tfs_file:
            given_TFS = tfs_file.read().split()
        logger.info("TFS given in --TFS: {0}".format(given_TFS))

        #Find matches between all and given
        logger.info("Matching given TFs with bindetect file...")
        matches = match_lists([given_TFS, all_TFS])

        for i, TF in enumerate(given_TFS):
            logger.info("- {0} matched with: {1}".format(TF, matches[0][i]))

        #Get tfs
        chosen_TFS = list(set(flatten_list(matches)))
        logger.info("Chosen TFS to view in plot: {0}".format(chosen_TFS))
    else:
        logger.info("Showing all TFS in plot. Please use --TFS to subset output.")
        chosen_TFS = all_TFS

    # Get order of conditions
    header = list(bindetect.columns.values)
    conditions_file = [element.replace("_bound", "") for element in header if "bound" in element]
    if args.conditions is None:
        args.conditions = conditions_file
    elif not all(cond in conditions_file for cond in args.conditions):
        #Report as an error and signal the failure through the exit status
        logger.error("--conditions {0} is not a subset of bindetect conditions ({1})".format(args.conditions, conditions_file))
        sys.exit(1)

    logger.info("Conditions in order: {0}".format(args.conditions))

    #------------------------------------ Make plots --------------------------------#

    logger.info("Plotting figure")

    #Each measure is plotted on its own page with the given y-axis label
    measures = [
        ("n_bound", "Number of sites predicted bound"),
        ("percent_bound", "Percent of sites predicted bound"),
        ("mean_score", "Mean binding score"),
    ]
    xvals = np.arange(0, len(args.conditions))

    fig_out = os.path.abspath(args.output)
    figure_pdf = PdfPages(fig_out, keep_empty=True)
    try:
        #Changes over time for different measures
        for cluster_flag in [False, True]:

            #Choose whether to show individual TFs or clusters
            if cluster_flag:
                #numeric_only: text columns (e.g. output_prefix) cannot be averaged
                table = bindetect.loc[chosen_TFS,].groupby("cluster").mean(numeric_only=True)  #mean of each column
            else:
                table = bindetect.loc[chosen_TFS]

            #Get colors ready (plt.get_cmap replaces the removed matplotlib.cm.get_cmap)
            cmap = plt.get_cmap('rainbow')
            colors = cmap(np.linspace(0, 1, len(table)))

            for measure, ylabel in measures:
                fig, ax = plt.subplots(figsize=(10, 5))
                for i, TF in enumerate(table.index):

                    if measure == "mean_score":
                        yvals = np.array([table.at[TF, "{0}_mean_score".format(cond)] for cond in args.conditions])
                    else:
                        yvals = np.array([table.at[TF, "{0}_bound".format(cond)] for cond in args.conditions])
                        if measure == "percent_bound":
                            yvals = yvals / table.at[TF, "total_tfbs"] * 100.0  #percent bound

                    ax.plot(xvals, yvals, color=colors[i], marker="o", label=TF)
                    ax.annotate(TF, (xvals[0]-0.1, yvals[0]), color=colors[i], horizontalalignment="right", verticalalignment="center", fontsize=6)

                #General
                plt.title("Changes in TF binding across conditions")
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=3, markerscale=0.5)
                plt.xticks(xvals, args.conditions)
                plt.xlabel("Conditions")
                plt.ylabel(ylabel, color="black")
                ax.tick_params('y', colors='black')

                plt.xlim(xvals[0]-2, xvals[-1]+0.5)  #extra space to the left for the TF annotations
                figure_pdf.savefig(fig, bbox_inches='tight')
                plt.close(fig)

        #NOTE: a cumulative between-condition difference plot (second y-axis) was sketched here
        #previously as an inactive string literal, but was never enabled.

    finally:
        #Finalize the pdf even if plotting fails
        figure_pdf.close()

    logger.info("Saved figure to {0}".format(args.output))
    logger.end()
Пример #8
0
def run_formatmotifs(args):
    """ Read, deduplicate, optionally filter and write motifs in the chosen format.

    Motifs are read from all files in args.input (directories are expanded). Motifs with
    an id that was already seen are dropped, keeping the first occurrence and the input
    order. If args.filter is given, only motifs whose name or id contains one of the
    whitespace-separated filter entries (case-insensitive) are kept. The result is
    written depending on args.task:
        - "split": one file per motif, named <motif.id>.<args.format>, in directory args.output
        - "join": all motifs in the single file args.output

    Parameters:
        args: namespace providing input, output, filter, task, format and verbosity.
    """

    check_required(args, ["input", "output"])  #Check input arguments
    motif_files = expand_dirs(args.input)  #Expand any dirs in input
    check_files(motif_files + [args.filter])  #Check if files exist

    # Create logger and write argument overview
    logger = TobiasLogger("FormatMotifs", args.verbosity)
    logger.begin()

    parser = add_formatmotifs_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files([args.output])

    ####### Getting ready #######
    if args.task == "split":
        logger.info("Making directory {0} if not existing".format(args.output))
        make_directory(args.output)  #Check and create output directory

    ### Read motifs from files ###
    logger.info("Reading input files...")
    motif_list = MotifList()
    for f in motif_files:
        logger.debug("- {0}".format(f))
        motif_list.extend(MotifList().from_file(f))

    logger.info("Read {} motifs\n".format(len(motif_list)))

    #Sort out duplicate motifs; first occurrence wins and the input order is kept
    seen_ids = set()
    unique_motifs = MotifList()
    for motif in motif_list:
        if motif.id not in seen_ids:
            seen_ids.add(motif.id)
            unique_motifs.append(motif)

    if len(unique_motifs) != len(motif_list):
        logger.info(
            "Found duplicate motif ids in file - choosing first motif with unique id."
        )
        motif_list = unique_motifs
        logger.info("Reduced to {0} unique motif ids".format(len(motif_list)))

    ### Filter motif list ###
    if args.filter is not None:

        #Read filter
        logger.info("Reading entries in {0}".format(args.filter))
        with open(args.filter, "r") as filter_file:
            entries = filter_file.read().split()
        logger.info("Read {0} unique filter values".format(len(set(entries))))

        #Lowercase the entries once instead of once per motif
        entries_lower = [(entry, entry.lower()) for entry in entries]

        #Match to input motifs
        logger.info("Matching motifs to filter")
        used_filters = []
        filtered_list = MotifList()
        for input_motif in motif_list:
            motif_name = input_motif.name.lower()
            motif_id = input_motif.id.lower()

            #The first matching filter entry selects the motif
            for entry, entry_lower in entries_lower:
                if entry_lower in motif_name or entry_lower in motif_id:
                    filtered_list.append(input_motif)
                    logger.debug(
                        "Selected motif {0} ({1}) due to filter value {2}".
                        format(input_motif.name, input_motif.id, entry))
                    used_filters.append(entry)
                    break
            else:
                logger.debug(
                    "Could not find any match to motif {0} ({1}) in filter".
                    format(input_motif.name, input_motif.id))

        logger.info("Filtered number of motifs from {0} to {1}".format(
            len(motif_list), len(filtered_list)))
        motif_list = filtered_list

        logger.debug("Filters not used: {0}".format(
            sorted(set(entries) - set(used_filters))))

    #### Write out results ####
    if args.task == "split":
        logger.info("Writing individual files to directory {0}".format(
            args.output))

        for motif in motif_list:
            motif_string = MotifList([motif]).as_string(args.format)

            #Open file and write
            out_path = os.path.join(args.output, motif.id + "." + args.format)
            logger.info("- {0}".format(out_path))
            with open(out_path, "w") as f_out:
                f_out.write(motif_string)

    elif args.task == "join":
        logger.info("Writing converted motifs to file {0}".format(args.output))

        motif_string = motif_list.as_string(args.format)
        with open(args.output, "w") as f_out:
            f_out.write(motif_string)

    logger.end()
Пример #9
0
def run_atacorrect(args):
    """
	Function for bias correction of input .bam files
	Calls functions in ATACorrect_functions and several internal classes
	"""

    #Test if required arguments were given:
    if args.bam == None:
        sys.exit("Error: No .bam-file given")
    if args.genome == None:
        sys.exit("Error: No .fasta-file given")
    if args.peaks == None:
        sys.exit("Error: No .peaks-file given")

    #Adjust some parameters depending on input
    args.prefix = os.path.splitext(os.path.basename(
        args.bam))[0] if args.prefix == None else args.prefix
    args.outdir = os.path.abspath(
        args.outdir) if args.outdir != None else os.path.abspath(os.getcwd())

    #Set output bigwigs based on input
    tracks = ["uncorrected", "bias", "expected", "corrected"]
    tracks = [track for track in tracks
              if track not in args.track_off]  # switch off printing

    if args.split_strands == True:
        strands = ["forward", "reverse"]
    else:
        strands = ["both"]

    output_bws = {}
    for track in tracks:
        output_bws[track] = {}
        for strand in strands:
            elements = [args.prefix, track] if strand == "both" else [
                args.prefix, track, strand
            ]
            output_bws[track][strand] = {
                "fn":
                os.path.join(args.outdir, "{0}.bw".format("_".join(elements)))
            }

    #Set all output files
    bam_out = os.path.join(args.outdir, args.prefix + "_atacorrect.bam")
    bigwigs = [
        output_bws[track][strand]["fn"]
        for (track, strand) in itertools.product(tracks, strands)
    ]
    figures_f = os.path.join(args.outdir,
                             "{0}_atacorrect.pdf".format(args.prefix))

    output_files = bigwigs + [figures_f]
    output_files = list(OrderedDict.fromkeys(
        output_files))  #remove duplicates due to "both" option

    strands = ["forward", "reverse"]

    #----------------------------------------------------------------------------------------------------#
    # Print info on run
    #----------------------------------------------------------------------------------------------------#

    logger = TobiasLogger("ATACorrect", args.verbosity)
    logger.begin()

    parser = add_atacorrect_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files(output_files)

    args.cores = check_cores(args.cores, logger)

    #----------------------------------------------------------------------------------------------------#
    # Test input file availability for reading
    #----------------------------------------------------------------------------------------------------#

    logger.info("----- Processing input data -----")

    logger.debug("Testing input file availability")
    check_files([args.bam, args.genome, args.peaks], "r")

    logger.debug("Testing output directory/file writeability")
    make_directory(args.outdir)
    check_files(output_files, "w")

    #Open pdf for figures
    figure_pdf = PdfPages(figures_f, keep_empty=False)

    #----------------------------------------------------------------------------------------------------#
    # Read information in bam/fasta
    #----------------------------------------------------------------------------------------------------#

    logger.info("Reading info from .bam file")
    bamfile = pysam.AlignmentFile(args.bam, "rb")
    if bamfile.has_index() == False:
        logger.warning("No index found for bamfile - creating one via pysam.")
        pysam.index(args.bam)

    bam_references = bamfile.references  #chromosomes in correct order
    bam_chrom_info = dict(zip(bamfile.references, bamfile.lengths))
    logger.debug("bam_chrom_info: {0}".format(bam_chrom_info))
    bamfile.close()

    logger.info("Reading info from .fasta file")
    fastafile = pysam.FastaFile(args.genome)
    fasta_chrom_info = dict(zip(fastafile.references, fastafile.lengths))
    logger.debug("fasta_chrom_info: {0}".format(fasta_chrom_info))
    fastafile.close()

    #Compare chrom lengths
    chrom_in_common = set(bam_chrom_info.keys()).intersection(
        fasta_chrom_info.keys())
    for chrom in chrom_in_common:
        bamlen = bam_chrom_info[chrom]
        fastalen = fasta_chrom_info[chrom]
        if bamlen != fastalen:
            logger.warning("(Fastafile)\t{0} has length {1}".format(
                chrom, fasta_chrom_info[chrom]))
            logger.warning("(Bamfile)\t{0} has length {1}".format(
                chrom, bam_chrom_info[chrom]))
            sys.exit(
                "Error: .bam and .fasta have different chromosome lengths. Please make sure the genome file is similar to the one used in mapping."
            )

    #Subset bam_references to those for which there are sequences in fasta
    chrom_not_in_fasta = set(bam_references) - set(fasta_chrom_info.keys())
    if len(chrom_not_in_fasta) > 1:
        logger.warning(
            "The following contigs in --bam did not have sequences in --fasta: {0}. NOTE: These contigs will be skipped in calculation and output."
            .format(chrom_not_in_fasta))

    bam_references = [ref for ref in bam_references if ref in fasta_chrom_info]
    chrom_in_common = [ref for ref in chrom_in_common if ref in bam_references]

    #----------------------------------------------------------------------------------------------------#
    # Read regions from bedfiles
    #----------------------------------------------------------------------------------------------------#

    logger.info("Processing input/output regions")

    #Chromosomes included in analysis
    genome_regions = RegionList().from_list([
        OneRegion([chrom, 0, bam_chrom_info[chrom]])
        for chrom in bam_references if not "M" in chrom
    ])  #full genome length
    chrom_in_common = [chrom for chrom in chrom_in_common if "M" not in chrom]
    logger.debug("CHROMS\t{0}".format("; ".join(
        ["{0} ({1})".format(reg.chrom, reg.end) for reg in genome_regions])))
    genome_bp = sum([region.get_length() for region in genome_regions])

    # Process peaks
    peak_regions = RegionList().from_bed(args.peaks)
    peak_regions.merge()
    for i in range(len(peak_regions) - 1, -1, -1):
        region = peak_regions[i]

        peak_regions[i] = region.check_boundary(
            bam_chrom_info, "cut")  #regions are cut/removed from list
        if peak_regions[i] is None:
            logger.warning(
                "Peak region {0} was removed at it is either out of bounds or not in the chromosomes given in genome/bam."
                .format(region.tup(), i + 1))
            del peak_regions[i]

    nonpeak_regions = deepcopy(genome_regions).subtract(peak_regions)

    # Process specific input regions if given
    if args.regions_in != None:
        input_regions = RegionList().from_bed(args.regions_in)
        input_regions.merge()
        input_regions.apply_method(OneRegion.check_boundary, bam_chrom_info,
                                   "cut")
    else:
        input_regions = nonpeak_regions

    # Process specific output regions
    if args.regions_out != None:
        output_regions = RegionList().from_bed(args.regions_out)
    else:
        output_regions = deepcopy(peak_regions)

    #Extend regions to make sure extend + flanking for window/flank are within boundaries
    flank_extend = args.k_flank + int(args.window / 2.0)
    output_regions.apply_method(OneRegion.extend_reg,
                                args.extend + flank_extend)
    output_regions.merge()
    output_regions.apply_method(OneRegion.check_boundary, bam_chrom_info,
                                "cut")
    output_regions.apply_method(
        OneRegion.extend_reg, -flank_extend
    )  #Cut to needed size knowing that the region will be extended in function

    #Remove blacklisted regions and chromosomes not in common
    blacklist_regions = RegionList().from_bed(
        args.blacklist) if args.blacklist != None else RegionList(
            [])  #fill in with regions from args.blacklist
    regions_dict = {
        "genome": genome_regions,
        "input_regions": input_regions,
        "output_regions": output_regions,
        "peak_regions": peak_regions,
        "nonpeak_regions": nonpeak_regions,
        "blacklist_regions": blacklist_regions
    }
    for sub in [
            "input_regions", "output_regions", "peak_regions",
            "nonpeak_regions"
    ]:
        regions_sub = regions_dict[sub]
        regions_sub.subtract(blacklist_regions)
        regions_sub = regions_sub.apply_method(OneRegion.split_region, 50000)

        regions_sub.keep_chroms(chrom_in_common)
        regions_dict[sub] = regions_sub

    #write beds to look at in igv
    #input_regions.write_bed(os.path.join(args.outdir, "input_regions.bed"))
    #output_regions.write_bed(os.path.join(args.outdir, "output_regions.bed"))
    #peak_regions.write_bed(os.path.join(args.outdir, "peak_regions.bed"))
    #nonpeak_regions.write_bed(os.path.join(args.outdir, "nonpeak_regions.bed"))

    #Sort according to order in bam_references:
    output_regions.loc_sort(bam_references)
    chrom_order = {bam_references[i]: i
                   for i in range(len(bam_references))
                   }  #for use later when sorting output

    #### Statistics about regions ####
    genome_bp = sum([region.get_length() for region in regions_dict["genome"]])
    for key in regions_dict:
        total_bp = sum([region.get_length() for region in regions_dict[key]])
        logger.stats("{0}: {1} regions | {2} bp | {3:.2f}% coverage".format(
            key, len(regions_dict[key]), total_bp, total_bp / genome_bp * 100))

    #Estallish variables for regions to be used
    input_regions = regions_dict["input_regions"]
    output_regions = regions_dict["output_regions"]
    peak_regions = regions_dict["peak_regions"]
    nonpeak_regions = regions_dict["nonpeak_regions"]

    #Exit if no input/output regions were found
    if len(input_regions) == 0 or len(output_regions) == 0 or len(
            peak_regions) == 0 or len(nonpeak_regions) == 0:
        logger.error("No regions found - exiting!")
        sys.exit()

    #----------------------------------------------------------------------------------------------------#
    # Estimate normalization factors
    #----------------------------------------------------------------------------------------------------#

    #Setup logger queue
    logger.debug("Setting up listener for log")
    logger.start_logger_queue()
    args.log_q = logger.queue

    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("----- Estimating normalization factors -----")

    #If normalization is to be calculated
    if not args.norm_off:

        #Reads in peaks/nonpeaks
        logger.info("Counting reads in peak regions")
        peak_region_chunks = peak_regions.chunks(args.split)
        reads_peaks = sum(
            run_parallel(count_reads, peak_region_chunks, [args], args.cores,
                         logger))
        logger.comment("")

        logger.info("Counting reads in nonpeak regions")
        nonpeak_region_chunks = nonpeak_regions.chunks(args.split)
        reads_nonpeaks = sum(
            run_parallel(count_reads, nonpeak_region_chunks, [args],
                         args.cores, logger))

        reads_total = reads_peaks + reads_nonpeaks

        logger.stats("TOTAL_READS\t{0}".format(reads_total))
        logger.stats("PEAK_READS\t{0}".format(reads_peaks))
        logger.stats("NONPEAK_READS\t{0}".format(reads_nonpeaks))

        lib_norm = 10000000 / reads_total
        frip = reads_peaks / reads_total
        correct_factor = lib_norm * (1 / frip)

        logger.stats("LIB_NORM\t{0:.5f}".format(lib_norm))
        logger.stats("FRiP\t{0:.5f}".format(frip))
    else:
        logger.info("Normalization was switched off")
        correct_factor = 1.0

    logger.stats("CORRECTION_FACTOR:\t{0:.5f}".format(correct_factor))

    #----------------------------------------------------------------------------------------------------#
    # Estimate sequence bias
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Started estimation of sequence bias...")

    input_region_chunks = input_regions.chunks(
        args.split)  #split to 100 chunks (also decides the step of output)
    out_lst = run_parallel(bias_estimation, input_region_chunks, [args],
                           args.cores,
                           logger)  #Output is list of AtacBias objects

    #Join objects
    estimated_bias = out_lst[0]  #initialize object with first output
    for output in out_lst[1:]:
        estimated_bias.join(
            output
        )  #bias object contains bias/background SequenceMatrix objects

    logger.debug("Bias estimated\tno_reads: {0}".format(
        estimated_bias.no_reads))

    #----------------------------------------------------------------------------------------------------#
    # Join estimations from all chunks of regions
    #----------------------------------------------------------------------------------------------------#

    bias_obj = estimated_bias
    bias_obj.correction_factor = correct_factor

    ### Bias motif ###
    logger.info("Finalizing bias motif for scoring")
    for strand in strands:
        bias_obj.bias[strand].prepare_mat()

        logger.debug("Saving pssm to figure pdf")
        fig = plot_pssm(bias_obj.bias[strand].pssm,
                        "Tn5 insertion bias of reads ({0})".format(strand))
        figure_pdf.savefig(fig)

    #Write bias motif to pickle
    out_f = os.path.join(args.outdir, args.prefix + "_AtacBias.pickle")
    logger.debug("Saving bias object to pickle ({0})".format(out_f))
    bias_obj.to_pickle(out_f)

    #----------------------------------------------------------------------------------------------------#
    # Correct read bias and write to bigwig
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("----- Correcting reads from .bam within output regions -----")

    output_regions.loc_sort(bam_references)  #sort in order of references
    output_regions_chunks = output_regions.chunks(args.split)
    no_tasks = float(len(output_regions_chunks))
    chunk_sizes = [len(chunk) for chunk in output_regions_chunks]
    logger.debug("All regions chunked: {0} ({1})".format(
        len(output_regions), chunk_sizes))

    ### Create key-file linking for bigwigs
    key2file = {}
    for track in output_bws:
        for strand in output_bws[track]:
            filename = output_bws[track][strand]["fn"]
            key = "{}:{}".format(track, strand)
            key2file[key] = filename

    #Start correction/write cores
    n_bigwig = len(key2file.values())
    writer_cores = min(n_bigwig, max(
        1, int(args.cores *
               0.1)))  #at most one core per bigwig or 10% of cores (or 1)
    worker_cores = max(1, args.cores - writer_cores)
    logger.debug("Worker cores: {0}".format(worker_cores))
    logger.debug("Writer cores: {0}".format(writer_cores))

    worker_pool = mp.Pool(processes=worker_cores)
    writer_pool = mp.Pool(processes=writer_cores)
    manager = mp.Manager()

    #Start bigwig file writers
    writer_tasks = []
    header = [(chrom, bam_chrom_info[chrom]) for chrom in bam_references]
    key_chunks = [
        list(key2file.keys())[i::writer_cores] for i in range(writer_cores)
    ]
    qs_list = []
    qs = {}
    for chunk in key_chunks:
        logger.debug("Creating writer queue for {0}".format(chunk))

        q = manager.Queue()
        qs_list.append(q)

        files = [key2file[key] for key in chunk]
        writer_tasks.append(
            writer_pool.apply_async(bigwig_writer,
                                    args=(q, dict(zip(chunk, files)), header,
                                          output_regions, args))
        )  #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x)))
        for key in chunk:
            qs[key] = q

    args.qs = qs
    writer_pool.close()  #no more jobs applied to writer_pool

    #Start correction
    logger.debug("Starting correction")
    task_list = [
        worker_pool.apply_async(bias_correction, args=[chunk, args, bias_obj])
        for chunk in output_regions_chunks
    ]
    worker_pool.close()
    monitor_progress(task_list, logger, "Correction progress:"
                     )  #does not exit until tasks in task_list finished
    results = [task.get() for task in task_list]

    #Get all results
    pre_bias = results[0][0]  #initialize with first result
    post_bias = results[0][1]  #initialize with first result
    for result in results[1:]:
        pre_bias_chunk = result[0]
        post_bias_chunk = result[1]

        for direction in strands:
            pre_bias[direction].add_counts(pre_bias_chunk[direction])
            post_bias[direction].add_counts(post_bias_chunk[direction])

    #Stop all queues for writing
    logger.debug("Stop all queues by inserting None")
    for q in qs_list:
        q.put((None, None, None))

    #Fetch error codes from bigwig writers
    logger.debug("Fetching possible errors from bigwig_writer tasks")
    results = [task.get()
               for task in writer_tasks]  #blocks until writers are finished

    logger.debug("Joining bigwig_writer queues")

    qsum = sum([q.qsize() for q in qs_list])
    while qsum != 0:
        qsum = sum([q.qsize() for q in qs_list])
        logger.spam("- Queue sizes {0}".format([(key, qs[key].qsize())
                                                for key in qs]))
        time.sleep(0.5)

    #Waits until all queues are closed
    writer_pool.join()
    worker_pool.terminate()
    worker_pool.join()

    #Stop multiprocessing logger
    logger.stop_logger_queue()

    #----------------------------------------------------------------------------------------------------#
    # Information and verification of corrected read frequencies
    #----------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Verifying bias correction")

    #Calculating variance per base
    for strand in strands:

        #Invert negative counts
        abssum = np.abs(np.sum(post_bias[strand].neg_counts, axis=0))
        post_bias[strand].neg_counts = post_bias[strand].neg_counts + abssum

        #Join negative/positive counts
        post_bias[strand].counts += post_bias[strand].neg_counts  #now pos

        pre_bias[strand].prepare_mat()
        post_bias[strand].prepare_mat()

        pre_var = np.mean(np.var(pre_bias[strand].bias_pwm,
                                 axis=1)[:4])  #mean of variance per nucleotide
        post_var = np.mean(np.var(post_bias[strand].bias_pwm, axis=1)[:4])
        logger.stats("BIAS\tpre-bias variance {0}:\t{1:.7f}".format(
            strand, pre_var))
        logger.stats("BIAS\tpost-bias variance {0}:\t{1:.7f}".format(
            strand, post_var))

        #Plot figure
        fig_title = "Nucleotide frequencies in corrected reads\n({0} strand)".format(
            strand)
        figure_pdf.savefig(
            plot_correction(pre_bias[strand].bias_pwm,
                            post_bias[strand].bias_pwm, fig_title))

    #----------------------------------------------------------------------------------------------------#
    # Finish up
    #----------------------------------------------------------------------------------------------------#

    plt.close('all')
    figure_pdf.close()
    logger.end()
Пример #10
0
def run_motifclust(args):
    """Cluster the motifs given in args.motifs and write clustering results and plots.

    Reads every file in ``args.motifs`` into one MotifList, clusters the motifs,
    builds one consensus motif per cluster and plots a dendrogram plus
    similarity heatmaps (one for all motifs, and one per pair of input files).

    Attributes read from ``args``: motifs, outdir, prefix, verbosity, threshold,
    dist_method, clust_method, type, dpi, cons_format, color.
    Attributes set on ``args`` as a side effect: nrc, ncc, zscore.

    Files written (``<p>`` is ``os.path.join(args.outdir, args.prefix)``):
        <p>_stats_motifs.txt, <p>_clusters.yml, <p>_matrix.txt,
        <p>_dendrogram.<type>, <p>_consensus_motifs.<cons_format>,
        <p>_heatmap_all.<type>, <p>_heatmap<i>.<type> and
        <outdir>/consensus_motifs_img/<motif id>_consensus.<type>.

    Raises:
        SystemExit: with exit status 1 if the 'gimmemotifs' package cannot be imported.
    """

    ###### Check input arguments ######
    check_required(args, ["motifs"])  #Check input arguments
    check_files([args.motifs])  #Check if files exist
    out_cons_img = os.path.join(args.outdir, "consensus_motifs_img")
    make_directory(out_cons_img)
    out_prefix = os.path.join(args.outdir, args.prefix)

    ###### Create logger and write argument overview ######
    logger = TobiasLogger("ClusterMotifs", args.verbosity)
    logger.begin()

    parser = add_motifclust_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    #logger.output_files([])

    #----------------------------------------- Check for gimmemotifs ----------------------------------------#

    #The two names are not used below; importing them only verifies that gimmemotifs is installed.
    #Only ImportError is caught so that unrelated failures are not misreported as a missing package.
    try:
        from gimmemotifs.motif import Motif  # noqa: F401
        from gimmemotifs.comparison import MotifComparer  # noqa: F401
    except ImportError:
        logger.error(
            "MotifClust requires the python package 'gimmemotifs'. You can install it using 'pip install gimmemotifs' or 'conda install gimmemotifs'."
        )
        sys.exit(1)  #non-zero exit status, as the run did not succeed
    else:
        sns.set_style(
            "ticks"
        )  #set style back to ticks, as this is set globally during gimmemotifs import

    #---------------------------------------- Reading motifs from file(s) -----------------------------------#
    logger.info("Reading input file(s)")

    motif_list = MotifList()  #list containing OneMotif objects
    motif_dict = {}  #dictionary containing separate motif lists per file

    if sys.version_info < (
            3, 7, 0):  # workaround for deepcopy of compiled regex patterns with python version < 3.7
        copy._deepcopy_dispatch[type(re.compile(''))] = lambda r, _: r

    for f in args.motifs:
        logger.debug("Reading {0}".format(f))

        #Read the file content for format detection; 'with' makes sure the handle is closed again
        with open(f) as motif_handle:
            motif_format = get_motif_format(motif_handle.read())
        sub_motif_list = MotifList().from_file(f)  #MotifList object

        logger.stats("- Read {0} motifs from {1} (format: {2})".format(
            len(sub_motif_list), f, motif_format))

        motif_list.extend(sub_motif_list)
        motif_dict[f] = sub_motif_list

    #Check whether ids are unique
    #TODO

    #---------------------------------------- Motif stats ---------------------------------------------------#
    logger.info("Creating matrix statistics")

    gimmemotifs_list = [
        motif.get_gimmemotif().gimme_obj for motif in motif_list
    ]

    #Stats for all motifs
    full_motifs_out = out_prefix + "_stats_motifs.txt"
    motifs_stats = get_motif_stats(gimmemotifs_list)
    write_motif_stats(motifs_stats, full_motifs_out)

    #---------------------------------------- Motif clustering ----------------------------------------------#
    logger.info("Clustering motifs")

    clusters = motif_list.cluster(threshold=args.threshold,
                                  metric=args.dist_method,
                                  clust_method=args.clust_method)
    logger.stats("- Identified {0} clusters".format(len(clusters)))

    #Write out overview of clusters (cluster id -> list of motif ids)
    cluster_dict = {
        cluster_id: [
            motif.get_gimmemotif().gimme_obj.id
            for motif in clusters[cluster_id]
        ]
        for cluster_id in clusters
    }
    cluster_f = out_prefix + "_" + "clusters.yml"
    logger.info("- Writing clustering to {0}".format(cluster_f))
    write_yaml(cluster_dict, cluster_f)

    # Save similarity matrix to file
    matrix_out = out_prefix + "_matrix.txt"
    logger.info("- Saving similarity matrix to the file: " + str(matrix_out))
    motif_list.similarity_matrix.to_csv(matrix_out, sep='\t')

    #Plot dendrogram
    logger.info("Plotting clustering dendrogram")
    dendrogram_f = out_prefix + "_" + "dendrogram." + args.type  #plot format pdf/png
    plot_dendrogram(motif_list.similarity_matrix.columns,
                    motif_list.linkage_mat, 12, dendrogram_f, "Clustering",
                    args.threshold, args.dpi)

    #---------------------------------------- Consensus motif -----------------------------------------------#
    logger.info("Building consensus motif for each cluster")

    consensus_motifs = MotifList()
    for cluster_id in clusters:
        consensus = clusters[cluster_id].create_consensus(
        )  #MotifList object with create_consensus method
        consensus.id = cluster_id if len(
            clusters[cluster_id]) > 1 else clusters[cluster_id][
                0].id  #set original motif id if cluster length = 1

        consensus_motifs.append(consensus)

    #Write out consensus motif file
    out_f = out_prefix + "_consensus_motifs." + args.cons_format
    logger.info("- Writing consensus motifs to: {0}".format(out_f))
    consensus_motifs.to_file(out_f, args.cons_format)

    #Create logo plots (in the folder created at the top of this function)
    logger.info(
        "- Making logo plots for consensus motifs (output folder: {0})".format(
            out_cons_img))
    for motif in consensus_motifs:
        filename = os.path.join(out_cons_img,
                                motif.id + "_consensus." + args.type)
        motif.logo_to_file(filename)

    #---------------------------------------- Plot heatmap --------------------------------------------------#

    logger.info("Plotting similarity heatmap")
    logger.info(
        "Note: Can take a while for --type=pdf. Try \"--type png\" for speed up."
    )
    #Fixed settings which are passed on to plot_heatmap below
    args.nrc = False
    args.ncc = False
    args.zscore = "None"
    clust_linkage = motif_list.linkage_mat
    similarity_matrix = motif_list.similarity_matrix

    pdf_out = out_prefix + "_heatmap_all." + args.type
    x_label = "All motifs"
    y_label = "All motifs"
    plot_heatmap(similarity_matrix, pdf_out, clust_linkage, clust_linkage,
                 args.dpi, x_label, y_label, args.color, args.ncc, args.nrc,
                 args.zscore)

    # Plot heatmaps for each combination of motif files
    comparisons = itertools.combinations(args.motifs, 2)
    for i, (motif_file_1, motif_file_2) in enumerate(comparisons):

        pdf_out = out_prefix + "_heatmap" + str(i) + "." + args.type
        logger.info("Plotting comparison of {0} and {1} motifs to the file ".
                    format(motif_file_1, motif_file_2) + str(pdf_out))

        x_label, y_label = motif_file_1, motif_file_2

        #Create subset of matrices for row/col clustering
        motif_names_1 = [
            motif.get_gimmemotif().gimme_obj.id
            for motif in motif_dict[motif_file_1]
        ]
        motif_names_2 = [
            motif.get_gimmemotif().gimme_obj.id
            for motif in motif_dict[motif_file_2]
        ]

        m1_matrix, m2_matrix, similarity_matrix_sub = subset_matrix(
            similarity_matrix, motif_names_1, motif_names_2)

        #Linkages are only calculated if both files contain more than one motif
        col_linkage = linkage(ssd.squareform(m1_matrix)) if (
            len(motif_names_1) > 1 and len(motif_names_2) > 1) else None
        row_linkage = linkage(ssd.squareform(m2_matrix)) if (
            len(motif_names_1) > 1 and len(motif_names_2) > 1) else None

        #Plot similarity heatmap between file1 and file2
        plot_heatmap(similarity_matrix_sub, pdf_out, col_linkage, row_linkage,
                     args.dpi, x_label, y_label, args.color, args.ncc,
                     args.nrc, args.zscore)

    # ClusterMotifs finished
    logger.end()
Пример #11
0
def run_network(args):
    """Build a TF-TF network from annotated binding sites and write edges, adjacency and paths.

    The binding site files in ``args.TFBS`` are concatenated and matched against the
    ``args.origin`` table. The column pair with the largest overlap of (upper-cased)
    values decides which columns hold source ids and target ids. Sites without a valid
    source or target are dropped, the remaining sites are merged with the origin table,
    and paths through the resulting graph are found with ``dfs``.

    Attributes read from ``args``: outdir, TFBS, origin, verbosity, start, max_len.

    Files written to ``args.outdir``:
        edges.txt, adjacency.txt, and for every start node
        <node>_paths.txt and <node>_path_edges.txt.
    """

    make_directory(args.outdir)
    check_required(
        args, ["TFBS", "origin"])  #check if anything is given for parameters
    check_files([args.TFBS, args.origin])

    logger = TobiasLogger("CreateNetwork", args.verbosity)
    logger.begin()

    #-------------------------- Origin file translating motif name -> gene origin -----------------------------------#
    #translation file, where one motif can constitute more than one gene (jun::fos)
    #and more genes can encode transcription factors with same motifs (close family members with same target sequence)
    origin_table = pd.read_csv(args.origin, sep="\t", header=None)
    origin_table.columns = [
        "Origin_" + str(element) for element in origin_table.columns
    ]
    origin_table.fillna("", inplace=True)  #replace NaN with empty string

    #------------------------ Transcription factor binding sites with mapping to target genes -----------------------#

    logger.info("Reading all input binding sites")

    #todo: read in parallel
    dataframes = []
    for fil in args.TFBS:
        logger.debug("- {0}".format(fil))

        df = pd.read_csv(fil, sep="\t", header=None)
        dataframes.append(df)

    logger.debug("Joining sites from all files")
    sites_table = pd.concat(dataframes, sort=False)
    sites_table.columns = [
        "Sites_" + str(element) for element in sites_table.columns
    ]
    sites_table.fillna("", inplace=True)  #replace NaN with empty string
    logger.info("- Total of {0} sites found\n".format(sites_table.shape[0]))

    #---------------------------------------- Match target columns to origin ----------------------------------------#

    logger.info("Matching target genes back to TFs using --origin")

    #String columns are established before all columns are converted to upper-case strings below
    origin_table_str_columns = list(
        origin_table.dtypes.index[origin_table.dtypes == "object"])
    sites_table_str_columns = list(
        sites_table.dtypes.index[sites_table.dtypes == "object"])

    origin_table = origin_table.apply(lambda x: x.astype(str).str.upper())
    sites_table = sites_table.apply(lambda x: x.astype(str).str.upper())

    #Establishing matched columns
    logger.debug(
        "Establishing which columns should be used for mapping target -> source"
    )

    #Collect (sites column, origin column, number of shared values) for all pairs of string columns
    matching = []
    for sites_column in sites_table_str_columns:
        sites_column_content = set(sites_table[sites_column])
        for origin_column in origin_table_str_columns:
            origin_column_content = set(origin_table[origin_column])

            #Overlap
            overlap = len(origin_column_content & sites_column_content)
            matching.append((sites_column, origin_column, overlap))

    sorted_matching = sorted(matching, key=lambda tup: -tup[-1])  #largest overlap first
    logger.debug("Match tuples: {0}".format(sorted_matching))

    #Columns for matching
    source_col_tfbs = sites_table.columns[
        3]  #Source id (TF name) is the 4th column of the sites bedfile
    source_col_origin = [
        match[1] for match in sorted_matching if match[0] == source_col_tfbs
    ][0]  #source id column (TF name) in the origin table
    target_col_tfbs = [
        match[0] for match in sorted_matching if match[0] != source_col_tfbs
    ][0]  #target id (gene id) column from sites
    target_col_origin = [
        match[1] for match in sorted_matching if match[0] != source_col_tfbs
    ][0]  #target id (gene id) in the origin table

    #Intersect of sources and targets
    source_ids_tfbs = set(sites_table[source_col_tfbs])
    logger.debug("Source ids from TFBS: {0}".format(
        ", ".join(list(source_ids_tfbs)[:4]) + " (...)"))

    source_ids_origin = set(origin_table[source_col_origin])
    logger.debug("Matched source ids from origin table: {0}".format(
        ", ".join(list(source_ids_origin)[:4]) + " (...)"))

    target_ids_tfbs = set(sites_table[target_col_tfbs])
    logger.debug("Target ids from TFBS: {0}".format(
        ", ".join(list(target_ids_tfbs)[:4]) + " (...)"))

    target_ids_origin = set(origin_table[target_col_origin])
    logger.debug("Matched target ids from origin table: {0}".format(
        ", ".join(list(target_ids_origin)[:4]) + " (...)"))

    common_ids = source_ids_tfbs & source_ids_origin
    if len(common_ids) != len(source_ids_tfbs):
        missing_ids = source_ids_tfbs - common_ids
        logger.warning(
            "The following source ids (4th column) from '--TFBS' could not be found in the '--origin' table: {0}"
            .format(missing_ids))

        #Subset sites_table to those with source within common_ids
        n_rows = sites_table.shape[0]
        sites_table = sites_table[sites_table[source_col_tfbs].isin(
            common_ids)]
        logger.warning(
            "Subset {0} sites to {1} sites with a valid source id in the '--origin' table"
            .format(n_rows, sites_table.shape[0]))

    #Remove sites without targets (NAN)
    n_rows = sites_table.shape[0]
    sites_table = sites_table[sites_table[target_col_tfbs] !=
                              ""]  #targets not empty
    logger.info("Subset {0} sites to {1} sites with any target given".format(
        n_rows, sites_table.shape[0]))

    #Subset sites_table to targets within target_ids_origin (through origin table) - e.g. only keep targets which are TFs
    n_rows = sites_table.shape[0]
    valid_targets = origin_table[target_col_origin]
    sites_table = sites_table[sites_table[target_col_tfbs].isin(valid_targets)]
    logger.info(
        "Subset {0} sites to {1} sites with matching target id in '--origin'".
        format(n_rows, sites_table.shape[0]))

    #Merge sites with origin table (to get target motif ids)
    n_rows = sites_table.shape[0]
    sites_table_convert = sites_table.merge(
        origin_table[[source_col_origin, target_col_origin]],
        left_on=target_col_tfbs,
        right_on=target_col_origin,
        how="inner")
    logger.info(
        "Merged sites/targets with '--origin' table: Continuing with {0} TF-target connections"
        .format(sites_table_convert.shape[0]))
    if n_rows < sites_table_convert.shape[0]:
        msg = "NOTE: The number of TF-target connections is higher than the number of unique sites, which occurs when the '--origin' table contains "
        msg += "target ids assigned to multiple motifs. In this case, CreateNetwork will treat each motif as an independent TF in the graph. "
        logger.info(msg)

    #Subset to unique edges if chosen
    #if args.unique == True:
    #	sites_table_convert.drop_duplicates(subset=[source_col_tfbs, source_col_origin], inplace=True)
    #	logger.info("Flag --unique is on: Sites were further subset to {0} unique edges".format(sites_table_convert.shape[0]))

    #------------------------------------ Write out edges / adjacency ----------------------------------------#
    logger.info("")  #create space in logger output

    ##### Write out edges #####
    edges_f = os.path.join(args.outdir, "edges.txt")
    logger.info("Writing edges list to: {0}".format(edges_f))
    sites_table_convert.to_csv(edges_f, sep="\t", index=False)

    ###### Create adjacency list ####
    logger.info("Creating adjacency matrix")
    adjacency = {source: {"targets": []} for source in common_ids}
    for index, row in sites_table_convert.iterrows():
        #After the merge, the origin source column holds the motif id of the target gene
        source, target = row[source_col_tfbs], row[source_col_origin]
        if target not in adjacency[source]["targets"]:
            adjacency[source]["targets"].append(target)

    adjacency_f = os.path.join(args.outdir, "adjacency.txt")
    logger.info("- Writing matrix to: {0}".format(adjacency_f))
    with open(adjacency_f, "w") as f:
        f.write("Source\tTargets\n")
        for source in sorted(adjacency):
            f.write("{0}\t{1}\n".format(
                source, ", ".join(adjacency[source]["targets"])))

    #-------------------------------------- Find paths through graph ---------------------------------------#

    #Create possible paths through graph
    logger.info("")
    logger.info("Create possible paths through graph")

    #Starting node can be used to subset path-finding to specific nodes; speeds up computation
    if args.start is not None:
        start_nodes = [
            one_id for one_id in common_ids if args.start.upper() in one_id
        ]
        logger.info("Starting nodes are: {0}".format(start_nodes))
    else:
        start_nodes = common_ids
        logger.info(
            "Finding paths starting from all nodes. This behavior can be changed using --start."
        )

    #Find paths starting at source nodes
    for source_id in start_nodes:
        logger.info("- Finding paths starting from {0}".format(source_id))

        #Recursive function to find paths; initiate with no paths found
        paths = dfs(adjacency=adjacency,
                    path=[source_id],
                    all_paths=[],
                    options={"max_length": args.max_len})

        paths_f = os.path.join(args.outdir, "{0}_paths.txt".format(source_id))
        logger.debug("-- Writing paths to: " + paths_f)

        #Unique (source, target, level) edges collected while writing paths.
        #Kept as tuples (not joined strings) so they can be sorted by level and source below.
        path_edges = set()
        with open(paths_f, "w") as paths_out:
            paths_out.write("Regulatory_path\tn_nodes\n")

            for path in paths:
                if len(path) > 1:  #only write paths longer than 1 node

                    #String formatting of path
                    str_paths = " --> ".join(path)

                    n_nodes = len(path)
                    paths_out.write("{0}\t{1}\n".format(str_paths, n_nodes))

                    #Make pairwise edges across path; level is the position of the target in the path
                    path_edges.update(
                        (path[level - 1], path[level], level)
                        for level in range(1, len(path)))

        #Write out the edges included in paths
        source_paths_f = os.path.join(args.outdir,
                                      "{0}_path_edges.txt".format(source_id))
        with open(source_paths_f, "w") as f:
            f.write("Source\tTarget\tLevel\n")

            #Sort by numeric level, then source; target is included to make the order deterministic
            for edge_source, edge_target, level in sorted(
                    path_edges, key=lambda edge: (edge[2], edge[0], edge[1])):
                f.write("\t".join([edge_source, edge_target, str(level)]) + "\n")

    #Finish CreateNetwork
    logger.end()
Пример #12
0
def run_bindetect(args):
    """ Main function to run bindetect algorithm with input files and parameters given in args """

    #Checking input and setting cond_names
    check_required(args, ["signals", "motifs", "genome", "peaks"])
    args.cond_names = [
        os.path.basename(os.path.splitext(bw)[0]) for bw in args.signals
    ] if args.cond_names is None else args.cond_names
    args.outdir = os.path.abspath(args.outdir)

    #Set output files
    states = ["bound", "unbound"]
    outfiles = [
        os.path.abspath(
            os.path.join(args.outdir, "*", "beds",
                         "*_{0}_{1}.bed".format(condition, state)))
        for (condition, state) in itertools.product(args.cond_names, states)
    ]
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir, "*", "beds", "*_all.bed")))
    outfiles.append(
        os.path.abspath(
            os.path.join(args.outdir, "*", "plots", "*_log2fcs.pdf")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir, "*", "*_overview.txt")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir, "*", "*_overview.xlsx")))

    outfiles.append(
        os.path.abspath(
            os.path.join(args.outdir, args.prefix + "_distances.txt")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir,
                                     args.prefix + "_results.txt")))
    outfiles.append(
        os.path.abspath(
            os.path.join(args.outdir, args.prefix + "_results.xlsx")))
    outfiles.append(
        os.path.abspath(os.path.join(args.outdir,
                                     args.prefix + "_figures.pdf")))

    #-------------------------------------------------------------------------------------------------------------#
    #-------------------------------------------- Setup logger and pool ------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger = TobiasLogger("BINDetect", args.verbosity)
    logger.begin()

    parser = add_bindetect_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files(outfiles)

    # Setup pool
    args.cores = check_cores(args.cores, logger)
    writer_cores = max(1, int(args.cores * 0.1))
    worker_cores = max(1, args.cores - writer_cores)
    logger.debug("Worker cores: {0}".format(worker_cores))
    logger.debug("Writer cores: {0}".format(writer_cores))

    pool = mp.Pool(processes=worker_cores)
    writer_pool = mp.Pool(processes=writer_cores)

    #-------------------------------------------------------------------------------------------------------------#
    #-------------------------- Pre-processing data: Reading motifs, sequences, peaks ----------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.info("----- Processing input data -----")

    #Check opening/writing of files
    logger.info("Checking reading/writing of files")
    check_files([args.signals, args.motifs, args.genome, args.peaks],
                action="r")
    check_files(outfiles[-3:], action="w")
    make_directory(args.outdir)

    #Comparisons between conditions
    no_conditions = len(args.signals)
    if args.time_series:
        comparisons = list(zip(args.cond_names[:-1], args.cond_names[1:]))
        args.comparisons = comparisons
    else:
        comparisons = list(itertools.combinations(args.cond_names,
                                                  2))  #all-against-all
        args.comparisons = comparisons

    #Open figure pdf and write overview
    fig_out = os.path.abspath(
        os.path.join(args.outdir, args.prefix + "_figures.pdf"))
    figure_pdf = PdfPages(fig_out, keep_empty=True)

    plt.figure()
    plt.axis('off')
    plt.text(0.5,
             0.8,
             "BINDETECT FIGURES",
             ha="center",
             va="center",
             fontsize=20)

    #output and order
    titles = []
    titles.append("Raw score distributions")
    titles.append("Normalized score distributions")
    if args.debug:
        for (cond1, cond2) in comparisons:
            titles.append("Background log2FCs ({0} / {1})".format(
                cond1, cond2))

    for (cond1, cond2) in comparisons:
        titles.append("BINDetect plot ({0} / {1})".format(cond1, cond2))

    plt.text(0.1,
             0.6,
             "\n".join([
                 "Page {0}) {1}".format(i + 2, titles[i])
                 for i in range(len(titles))
             ]) + "\n\n",
             va="top")
    figure_pdf.savefig(bbox_inches='tight')
    plt.close()

    ################# Read peaks ################
    #Read peak and peak_header
    logger.info("Reading peaks")
    peaks = RegionList().from_bed(args.peaks)
    logger.info("- Found {0} regions in input peaks".format(len(peaks)))
    peaks = peaks.merge()  #merge overlapping peaks
    logger.info("- Merged to {0} regions".format(len(peaks)))

    if len(peaks) == 0:
        logger.error("Input --peaks file is empty!")
        sys.exit()

    #Read header and check match with number of peak columns
    peak_columns = len(peaks[0])  #number of columns
    if args.peak_header != None:
        content = open(args.peak_header, "r").read()
        args.peak_header_list = content.split()
        logger.debug("Peak header: {0}".format(args.peak_header_list))

        #Check whether peak header fits with number of peak columns
        if len(args.peak_header_list) != peak_columns:
            logger.error(
                "Length of --peak_header ({0}) does not fit number of columns in --peaks ({1})."
                .format(len(args.peak_header_list), peak_columns))
            sys.exit()
    else:
        args.peak_header_list = ["peak_chr", "peak_start", "peak_end"] + [
            "additional_" + str(num + 1) for num in range(peak_columns - 3)
        ]
    logger.debug("Peak header list: {0}".format(args.peak_header_list))

    ################# Check for match between peaks and fasta/bigwig #################
    logger.info(
        "Checking for match between --peaks and --fasta/--signals boundaries")
    logger.info("- Comparing peaks to {0}".format(args.genome))
    fasta_obj = pysam.FastaFile(args.genome)
    fasta_boundaries = dict(zip(fasta_obj.references, fasta_obj.lengths))
    fasta_obj.close()
    logger.debug("Fasta boundaries: {0}".format(fasta_boundaries))
    peaks = peaks.apply_method(OneRegion.check_boundary, fasta_boundaries,
                               "exit")  #will exit if peaks are outside borders

    #Check boundaries of each bigwig signal individually
    for signal in args.signals:
        logger.info("- Comparing peaks to {0}".format(signal))
        pybw_obj = pybw.open(signal)
        pybw_header = pybw_obj.chroms()
        pybw_obj.close()
        logger.debug("Signal boundaries: {0}".format(pybw_header))
        peaks = peaks.apply_method(OneRegion.check_boundary, pybw_header,
                                   "exit")

    ##### GC content for motif scanning ######
    #Make chunks of regions for multiprocessing
    logger.info("Estimating GC content from peak sequences")
    peak_chunks = peaks.chunks(args.split)
    gc_content_pool = pool.starmap(
        get_gc_content, itertools.product(peak_chunks, [args.genome]))
    gc_content = np.mean(gc_content_pool)  #fraction
    args.gc = gc_content
    bg = np.array([(1 - args.gc) / 2.0, args.gc / 2.0, args.gc / 2.0,
                   (1 - args.gc) / 2.0])
    logger.info("- GC content estimated at {0:.2f}%".format(gc_content * 100))

    ################ Get motifs ################
    logger.info("Reading motifs from file")
    motif_list = MotifList()
    args.motifs = expand_dirs(args.motifs)
    for f in args.motifs:
        motif_list += MotifList().from_file(f)  #List of OneMotif objects
    no_pfms = len(motif_list)
    logger.info("- Read {0} motifs".format(no_pfms))

    logger.debug("Getting motifs ready")
    motif_list.bg = bg

    logger.debug("Getting reverse motifs")
    motif_list.extend([motif.get_reverse() for motif in motif_list])
    logger.spam(motif_list)

    #Set prefixes
    for motif in motif_list:  #now with reverse motifs as well
        motif.set_prefix(args.naming)
        motif.bg = bg

        logger.spam("Getting pssm for motif {0}".format(motif.name))
        motif.get_pssm()

    motif_names = list(set([motif.prefix for motif in motif_list]))

    #Get threshold for motifs
    logger.debug("Getting match threshold per motif")
    outlist = pool.starmap(OneMotif.get_threshold,
                           itertools.product(motif_list, [args.motif_pvalue]))

    motif_list = MotifList(outlist)
    for motif in motif_list:
        logger.debug("Motif {0}: threshold {1}".format(motif.name,
                                                       motif.threshold))

    logger.info("Creating folder structure for each TF")
    for TF in motif_names:
        logger.spam("Creating directories for {0}".format(TF))
        make_directory(os.path.join(args.outdir, TF))
        make_directory(os.path.join(args.outdir, TF, "beds"))
        make_directory(os.path.join(args.outdir, TF, "plots"))

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------------------- Plot logos for all motifs -----------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    plus_motifs = [motif for motif in motif_list if motif.strand == "+"]
    logo_filenames = {
        motif.prefix: os.path.join(args.outdir, motif.prefix,
                                   motif.prefix + ".png")
        for motif in plus_motifs
    }

    logger.info("Plotting sequence logos for each motif")
    task_list = [
        pool.apply_async(OneMotif.logo_to_file, (
            motif,
            logo_filenames[motif.prefix],
        )) for motif in plus_motifs
    ]
    monitor_progress(task_list, logger)
    results = [task.get() for task in task_list]
    logger.comment("")

    logger.debug("Getting base64 strings per motif")
    for motif in motif_list:
        if motif.strand == "+":
            #motif.get_base()
            with open(logo_filenames[motif.prefix], "rb") as png:
                motif.base = base64.b64encode(png.read()).decode("utf-8")

    #-------------------------------------------------------------------------------------------------------------#
    #--------------------- Motif scanning: Find binding sites and match to footprint scores ----------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.start_logger_queue(
    )  #start process for listening and handling through the main logger queue
    args.log_q = logger.queue  #queue for multiprocessing logging
    manager = mp.Manager()
    logger.info("Scanning for motifs and matching to signals...")

    #Create writer queues for bed-file output
    logger.debug("Setting up writer queues")
    qs_list = []
    writer_qs = {}

    #writer_queue = create_writer_queue(key2file, writer_cores)
    #writer_queue.stop()	#wait until all are done

    manager = mp.Manager()
    TF_names_chunks = [
        motif_names[i::writer_cores] for i in range(writer_cores)
    ]
    for TF_names_sub in TF_names_chunks:
        logger.debug("Creating writer queue for {0}".format(TF_names_sub))
        files = [
            os.path.join(args.outdir, TF, "beds", TF + ".tmp")
            for TF in TF_names_sub
        ]

        q = manager.Queue()
        qs_list.append(q)

        writer_pool.apply_async(
            file_writer, args=(q, dict(zip(TF_names_sub, files)), args)
        )  #, callback = lambda x: finished.append(x) print("Writing time: {0}".format(x)))
        for TF in TF_names_sub:
            writer_qs[TF] = q
    writer_pool.close()  #no more jobs applied to writer_pool

    #todo: use run_parallel
    #Start working on data
    if worker_cores == 1:
        logger.debug("Running with cores = 1")
        results = []
        for chunk in peak_chunks:
            results.append(
                scan_and_score(chunk, motif_list, args, args.log_q, writer_qs))

    else:
        logger.debug("Sending jobs to worker pool")

        task_list = [
            pool.apply_async(scan_and_score, (
                chunk,
                motif_list,
                args,
                args.log_q,
                writer_qs,
            )) for chunk in peak_chunks
        ]
        monitor_progress(task_list, logger)
        results = [task.get() for task in task_list]

    logger.info("Done scanning for TFBS across regions!")
    #logger.stop_logger_queue()	#stop the listening process (wait until all was written)

    #--------------------------------------#
    logger.info("Waiting for bedfiles to write")

    #Stop all queues for writing
    logger.debug("Stop all queues by inserting None")
    for q in qs_list:
        q.put((None, None))

    logger.debug("Joining bed_writer queues")
    for i, q in enumerate(qs_list):
        logger.debug("- Queue {0} (size {1})".format(i, q.qsize()))

    #Waits until all queues are closed
    writer_pool.join()

    #-------------------------------------------------------------------------------------------------------------#
    #---------------------------- Process information on background scores and overlaps --------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.info("Merging results from subsets")
    background = merge_dicts([result[0] for result in results])
    TF_overlaps = merge_dicts([result[1] for result in results])
    results = None

    #Add missing TF overlaps (if some TFs had zero sites)
    for TF1 in plus_motifs:
        if TF1.prefix not in TF_overlaps:
            TF_overlaps[TF1.prefix] = 0
        for TF2 in plus_motifs:
            tup = (TF1.prefix, TF2.prefix)
            if tup not in TF_overlaps:
                TF_overlaps[tup] = 0

    #Collect sampled background values
    for bigwig in args.cond_names:
        background["signal"][bigwig] = np.array(background["signal"][bigwig])

    #Check how many values were fetched from background
    n_bg_values = len(background["signal"][args.cond_names[0]])
    logger.debug("Collected {0} values from background".format(n_bg_values))
    if n_bg_values < 1000:
        err_str = "Number of background values collected from peaks is low (={0}) ".format(
            n_bg_values)
        err_str += "- this affects estimation of the bound/unbound threshold and the normalization between conditions. "
        err_str += "To improve this estimation, please run BINDetect with --peaks = the full peak set across all conditions."
        logger.warning(err_str)

    logger.comment("")
    logger.info("Estimating score distribution per condition")

    fig = plot_score_distribution(
        [background["signal"][bigwig] for bigwig in args.cond_names],
        labels=args.cond_names,
        title="Raw scores per condition")
    figure_pdf.savefig(fig, bbox_inches='tight')
    plt.close()

    logger.info("Normalizing scores")
    list_of_vals = [background["signal"][bigwig] for bigwig in args.cond_names]
    normed, norm_objects = quantile_normalization(list_of_vals)

    args.norm_objects = dict(zip(args.cond_names, norm_objects))
    for bigwig in args.cond_names:
        background["signal"][bigwig] = args.norm_objects[bigwig].normalize(
            background["signal"][bigwig])

    fig = plot_score_distribution(
        [background["signal"][bigwig] for bigwig in args.cond_names],
        labels=args.cond_names,
        title="Normalized scores per condition")
    figure_pdf.savefig(fig, bbox_inches='tight')
    plt.close()

    ###########################################################
    logger.info("Estimating bound/unbound threshold")

    #Prepare scores (remove 0's etc.)
    bg_values = np.array(normed).flatten()
    bg_values = bg_values[np.logical_not(np.isclose(
        bg_values, 0.0))]  #only non-zero counts
    x_max = np.percentile(bg_values, [99])
    bg_values = bg_values[bg_values < x_max]

    if len(bg_values) == 0:
        logger.error(
            "Error processing bigwig scores from background. It could be that there are no scores in the bigwig (=0) assigned for the peaks. Please check your input files."
        )
        sys.exit()

    #Fit mixture of normals
    lowest_bic = np.inf
    for n_components in [2]:  #2 components
        gmm = sklearn.mixture.GaussianMixture(n_components=n_components,
                                              random_state=1)
        gmm.fit(np.log(bg_values).reshape(-1, 1))

        bic = gmm.bic(np.log(bg_values).reshape(-1, 1))
        logger.debug("n_compontents: {0} | bic: {1}".format(n_components, bic))
        if bic < lowest_bic:
            lowest_bic = bic
            best_gmm = gmm
    gmm = best_gmm

    #Extract most-right gaussian
    means = gmm.means_.flatten()
    sds = np.sqrt(gmm.covariances_).flatten()
    chosen_i = np.argmax(means)  #Mixture with largest mean

    log_params = scipy.stats.lognorm.fit(bg_values[bg_values < x_max],
                                         f0=sds[chosen_i],
                                         fscale=np.exp(means[chosen_i]))
    #all_log_params[bigwig] = log_params

    #Mode of distribution
    mode = scipy.optimize.fmin(
        lambda x: -scipy.stats.lognorm.pdf(x, *log_params), 0, disp=False)[0]
    logger.debug("- Mode estimated at: {0}".format(mode))
    pseudo = mode / 2.0  #pseudo is half the mode
    args.pseudo = pseudo
    logger.debug("Pseudocount estimated at: {0}".format(round(args.pseudo, 5)))

    # Estimate theoretical normal for threshold
    leftside_x = np.linspace(
        scipy.stats.lognorm(*log_params).ppf([0.01]), mode, 100)
    leftside_pdf = scipy.stats.lognorm.pdf(leftside_x, *log_params)

    #Flip over
    mirrored_x = np.concatenate([leftside_x,
                                 np.max(leftside_x) + leftside_x]).flatten()
    mirrored_pdf = np.concatenate([leftside_pdf, leftside_pdf[::-1]]).flatten()
    popt, cov = scipy.optimize.curve_fit(
        lambda x, std, sc: sc * scipy.stats.norm.pdf(x, mode, std), mirrored_x,
        mirrored_pdf)
    norm_params = (mode, popt[0])
    logger.debug("Theoretical normal parameters: {0}".format(norm_params))

    #Set threshold for bound/unbound
    threshold = round(
        scipy.stats.norm.ppf(1 - args.bound_pvalue, *norm_params), 5)

    args.thresholds = {bigwig: threshold for bigwig in args.cond_names}
    logger.stats("- Threshold estimated at: {0}".format(threshold))

    #Only plot if args.debug is True
    if args.debug:

        #Plot fit
        fig, ax = plt.subplots(1, 1)
        ax.hist(bg_values[bg_values < x_max],
                bins='auto',
                density=True,
                label="Observed score distribution")

        xvals = np.linspace(0, x_max, 1000)
        log_probas = scipy.stats.lognorm.pdf(xvals, *log_params)
        ax.plot(xvals, log_probas, label="Log-normal fit", color="orange")

        #Theoretical normal
        norm_probas = scipy.stats.norm.pdf(xvals, *norm_params)
        ax.plot(xvals,
                norm_probas * (np.max(log_probas) / np.max(norm_probas)),
                color="grey",
                linestyle="--",
                label="Theoretical normal")

        ax.axvline(threshold, color="black", label="Bound/unbound threshold")
        ymax = plt.ylim()[1]
        ax.text(threshold, ymax, "\n {0:.3f}".format(threshold), va="top")

        #Decorate plot
        plt.title("Score distribution")
        plt.xlabel("Bigwig score")
        plt.ylabel("Density")
        plt.legend(fontsize=8)
        plt.xlim((0, x_max))

        figure_pdf.savefig(fig)
        plt.close(fig)

    ############ Foldchanges between conditions ################
    logger.comment("")
    log2fc_params = {}
    if len(args.signals) > 1:
        logger.info(
            "Calculating background log2 fold-changes between conditions")

        for (bigwig1, bigwig2) in comparisons:  #cond1, cond2
            logger.info("- {0} / {1}".format(bigwig1, bigwig2))

            #Estimate background log2fc
            scores1 = np.copy(background["signal"][bigwig1])
            scores2 = np.copy(background["signal"][bigwig2])

            included = np.logical_or(scores1 > 0, scores2 > 0)
            scores1 = scores1[included]
            scores2 = scores2[included]

            #Calculate background log2fc normal disitribution
            log2fcs = np.log2(
                np.true_divide(scores1 + args.pseudo, scores2 + args.pseudo))

            lower, upper = np.percentile(log2fcs, [1, 99])
            log2fcs_fit = log2fcs[np.logical_and(log2fcs >= lower,
                                                 log2fcs <= upper)]

            norm_params = scipy.stats.norm.fit(log2fcs_fit)

            logger.debug(
                "({0} / {1}) Background log2fc normal distribution: {2}".
                format(bigwig1, bigwig2, norm_params))
            log2fc_params[(bigwig1, bigwig2)] = norm_params

            #Plot background log2fc to figures
            fig, ax = plt.subplots(1, 1)
            plt.hist(log2fcs,
                     density=True,
                     bins='auto',
                     label="Background log2fc ({0} / {1})".format(
                         bigwig1, bigwig2))

            xvals = np.linspace(plt.xlim()[0], plt.xlim()[1], 100)
            pdf = scipy.stats.norm.pdf(xvals,
                                       *log2fc_params[(bigwig1, bigwig2)])
            plt.plot(xvals, pdf, label="Normal distribution fit")
            plt.title("Background log2FCs ({0} / {1})".format(
                bigwig1, bigwig2))

            plt.xlabel("Log2 fold change")
            plt.ylabel("Density")
            if args.debug:
                figure_pdf.savefig(fig, bbox_inches='tight')

                f = open(
                    os.path.join(
                        args.outdir,
                        "{0}_{1}_log2fcs.txt".format(bigwig1, bigwig2)), "w")
                f.write("\n".join([str(val) for val in log2fcs]))
                f.close()
            plt.close()

    background = None  #free up space

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------- Read total sites per TF to estimate bound/unbound -----------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Processing scanned TFBS individually")

    #Getting bindetect table ready
    info_columns = ["total_tfbs"]
    info_columns.extend([
        "{0}_{1}".format(cond, metric)
        for (cond, metric
             ) in itertools.product(args.cond_names, ["threshold", "bound"])
    ])
    info_columns.extend([
        "{0}_{1}_{2}".format(comparison[0], comparison[1], metric)
        for (comparison,
             metric) in itertools.product(comparisons, ["change", "pvalue"])
    ])

    cols = len(info_columns)
    rows = len(motif_names)
    info_table = pd.DataFrame(np.zeros((rows, cols)),
                              columns=info_columns,
                              index=motif_names)

    #Starting calculations
    results = []
    if args.cores == 1:
        for name in motif_names:
            logger.info("- {0}".format(name))
            results.append(process_tfbs(name, args, log2fc_params))
    else:
        logger.debug("Sending jobs to worker pool")

        task_list = [
            pool.apply_async(process_tfbs, (
                name,
                args,
                log2fc_params,
            )) for name in motif_names
        ]
        monitor_progress(task_list,
                         logger)  #will not exit before all jobs are done
        results = [task.get() for task in task_list]

    logger.info("Concatenating results from subsets")
    info_table = pd.concat(results)  #pandas tables

    pool.terminate()
    pool.join()

    logger.stop_logger_queue()

    #-------------------------------------------------------------------------------------------------------------#
    #------------------------------------------------ Cluster TFBS -----------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    clustering = RegionCluster(TF_overlaps)
    clustering.cluster()

    #Convert full ids to alt ids
    convert = {motif.prefix: motif.name for motif in motif_list}
    for cluster in clustering.clusters:
        for name in convert:
            clustering.clusters[cluster]["cluster_name"] = clustering.clusters[
                cluster]["cluster_name"].replace(name, convert[name])

    #Write out distance matrix
    matrix_out = os.path.join(args.outdir, args.prefix + "_distances.txt")
    clustering.write_distance_mat(matrix_out)

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------------------- Write all_bindetect file ------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    logger.comment("")
    logger.info("Writing all_bindetect files")

    #Add columns of name / motif_id / prefix
    names = []
    ids = []
    for prefix in info_table.index:
        motif = [motif for motif in motif_list if motif.prefix == prefix]
        names.append(motif[0].name)
        ids.append(motif[0].id)

    info_table.insert(0, "output_prefix", info_table.index)
    info_table.insert(1, "name", names)
    info_table.insert(2, "motif_id", ids)

    #info_table.insert(3, "motif_logo", [os.path.join("motif_logos", os.path.basename(logo_filenames[prefix])) for prefix in info_table["output_prefix"]])	#add relative path to logo

    #Add cluster to info_table
    cluster_names = []
    for name in info_table.index:
        for cluster in clustering.clusters:
            if name in clustering.clusters[cluster]["member_names"]:
                cluster_names.append(
                    clustering.clusters[cluster]["cluster_name"])

    info_table.insert(3, "cluster", cluster_names)

    #Cluster table on motif clusters
    info_table_clustered = info_table.groupby(
        "cluster").mean()  #mean of each column
    info_table_clustered.reset_index(inplace=True)

    #Map correct type
    info_table["total_tfbs"] = info_table["total_tfbs"].map(int)
    for condition in args.cond_names:
        info_table[condition + "_bound"] = info_table[condition +
                                                      "_bound"].map(int)

    #### Write excel ###
    bindetect_excel = os.path.join(args.outdir, args.prefix + "_results.xlsx")
    writer = pd.ExcelWriter(bindetect_excel, engine='xlsxwriter')

    #Tables
    info_table.to_excel(writer, index=False, sheet_name="Individual motifs")
    info_table_clustered.to_excel(writer,
                                  index=False,
                                  sheet_name="Motif clusters")

    for sheet in writer.sheets:
        worksheet = writer.sheets[sheet]
        n_rows = worksheet.dim_rowmax
        n_cols = worksheet.dim_colmax
        worksheet.autofilter(0, 0, n_rows, n_cols)
    writer.save()

    #Format comparisons
    for (cond1, cond2) in comparisons:
        base = cond1 + "_" + cond2
        info_table[base + "_change"] = info_table[base + "_change"].round(5)
        info_table[base + "_pvalue"] = info_table[base + "_pvalue"].map(
            "{:.5E}".format, na_action="ignore")

    #Write bindetect results tables
    #info_table.insert(0, "TF_name", info_table.index)	 #Set index as first column
    bindetect_out = os.path.join(args.outdir, args.prefix + "_results.txt")
    info_table.to_csv(bindetect_out,
                      sep="\t",
                      index=False,
                      header=True,
                      na_rep="NA")

    #-------------------------------------------------------------------------------------------------------------#
    #------------------------------------------- Make BINDetect plot ---------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    if no_conditions > 1:
        logger.info("Creating BINDetect plot(s)")

        #Fill NAs from info_table to enable plotting of log2fcs (NA -> 0 change)
        change_cols = [col for col in info_table.columns if "_change" in col]
        pvalue_cols = [col for col in info_table.columns if "_pvalue" in col]
        info_table[change_cols] = info_table[change_cols].fillna(0)
        info_table[pvalue_cols] = info_table[pvalue_cols].fillna(1)

        #Plotting bindetect per comparison
        for (cond1, cond2) in comparisons:

            logger.info("- {0} / {1}".format(cond1, cond2))
            base = cond1 + "_" + cond2

            #Define which motifs to show
            xvalues = info_table[base + "_change"].astype(float)
            yvalues = info_table[base + "_pvalue"].astype(float)
            y_min = np.percentile(yvalues[yvalues > 0],
                                  5)  #5% smallest pvalues
            x_min, x_max = np.percentile(
                xvalues, [5, 95])  #5% smallest and largest changes

            #Make copy of motifs and fill in with metadata
            comparison_motifs = [
                motif for motif in motif_list if motif.strand == "+"
            ]  #copy.deepcopy(motif_list) - swig pickle error, just overwrite motif_list
            for motif in comparison_motifs:
                name = motif.prefix
                motif.change = float(info_table.at[name, base + "_change"])
                motif.pvalue = float(info_table.at[name, base + "_pvalue"])
                motif.logpvalue = -np.log10(
                    motif.pvalue) if motif.pvalue > 0 else -np.log10(1e-308)

                #Assign each motif to group
                if motif.change < x_min or motif.change > x_max or motif.pvalue < y_min:
                    if motif.change < 0:
                        motif.group = cond2 + "_up"
                    if motif.change > 0:
                        motif.group = cond1 + "_up"
                else:
                    motif.group = "n.s."

            #Bindetect plot
            fig = plot_bindetect(comparison_motifs, clustering, [cond1, cond2],
                                 args)
            figure_pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)

            #Interactive BINDetect plot
            html_out = os.path.join(args.outdir, "bindetect_" + base + ".html")
            plot_interactive_bindetect(comparison_motifs, [cond1, cond2],
                                       html_out)

    #-------------------------------------------------------------------------------------------------------------#
    #----------------------------- Make heatmap across conditions (for debugging)---------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    if args.debug:

        mean_columns = [cond + "_mean_score" for cond in args.cond_names]
        heatmap_table = info_table[mean_columns]
        heatmap_table.index = info_table["output_prefix"]

        #Decide fig size
        rows, cols = heatmap_table.shape
        figsize = (7 + cols, max(10, rows / 8.0))
        cm = sns.clustermap(
            heatmap_table,
            figsize=figsize,
            z_score=0,  #zscore for rows
            col_cluster=False,  #do not cluster condition columns
            yticklabels=True,  #show all row annotations
            xticklabels=True,
            cbar_pos=(0, 0, .4, .005),
            dendrogram_ratio=(0.3, 0.01),
            cbar_kws={
                "orientation": "horizontal",
                'label': 'Row z-score'
            },
            method="single")

        #Adjust width of columns
        #hm = cm.ax_heatmap.get_position()
        #cm.ax_heatmap.set_position([hm.x0, hm.y0, cols * 3 * hm.height / rows, hm.height]) 	#aspect should be equal

        plt.setp(cm.ax_heatmap.get_xticklabels(),
                 fontsize=8,
                 rotation=45,
                 ha="right")
        plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=5)

        cm.ax_col_dendrogram.set_title('Mean scores across conditions',
                                       fontsize=20)
        cm.ax_heatmap.set_ylabel("Transcription factor motifs",
                                 fontsize=15,
                                 rotation=270)
        #cm.ax_heatmap.set_title('Conditions')
        #cm.fig.suptitle('Mean scores across conditions')
        #cm.cax.set_visible(False)

        #Save to output pdf
        plt.tight_layout()
        figure_pdf.savefig(cm.fig, bbox_inches='tight')
        plt.close(cm.fig)

    #-------------------------------------------------------------------------------------------------------------#
    #-------------------------------------------------- Wrap up---------------------------------------------------#
    #-------------------------------------------------------------------------------------------------------------#

    figure_pdf.close()
    logger.end()
Пример #13
0
def run_motifclust(args):
    """Cluster the motifs given in args.motifs and write all result files.

    Steps, in order:
      1. Validate input and verify that 'gimmemotifs' can be imported.
      2. Read every motif file into one combined MotifList (and one list per file).
      3. Write per-motif matrix statistics.
      4. Cluster the motifs; write the cluster overview (yml), the similarity
         matrix and a dendrogram.
      5. Build one consensus motif per cluster; write them to file and as logos.
      6. Plot a similarity heatmap for all motifs, plus one heatmap per pair
         of input files.

    Parameters
    ----------
    args : argparse.Namespace-like object
        Attributes read here: motifs, outdir, prefix, verbosity, dist_method,
        threshold, clust_method, type, dpi, cons_format, color.
        Attributes set here (used for the heatmap calls): nrc, ncc, zscore.

    Raises
    ------
    SystemExit
        With status 1 if 'gimmemotifs' is missing or fails to import for a
        known reason, or if args.dist_method is not usable with the installed
        gimmemotifs version.
    ImportError
        Re-raised if the gimmemotifs import fails for an unrecognized reason.
    """

    ###### Check input arguments ######
    check_required(args, ["motifs"])  #Check input arguments
    check_files([args.motifs])  #Check if files exist

    #Output locations are derived once here and reused below
    out_cons_img = os.path.join(args.outdir, "consensus_motifs_img")
    make_directory(out_cons_img)
    out_prefix = os.path.join(args.outdir, args.prefix)

    ###### Create logger and write argument overview ######
    logger = TobiasLogger("ClusterMotifs", args.verbosity)
    logger.begin()

    parser = add_motifclust_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)

    #----------------------------------------- Check for gimmemotifs ----------------------------------------#

    try:
        #These imports only serve as a probe that gimmemotifs is importable
        from gimmemotifs.motif import Motif
        from gimmemotifs.comparison import MotifComparer
        sns.set_style(
            "ticks"
        )  #set style back to ticks, as this is set globally during gimmemotifs import

    except ModuleNotFoundError:
        logger.error(
            "MotifClust requires the python package 'gimmemotifs'. You can install it using 'pip install gimmemotifs' or 'conda install gimmemotifs'."
        )
        sys.exit(1)

    except ImportError as e:  #gimmemotifs was installed, but there was an error during import

        pandas_version = pd.__version__
        python_version = platform.python_version()

        if e.name == "collections" and (
                version.parse(python_version) >= version.parse("3.10.0")
        ):  #collections error from norns=0.1.5 and from other packages on python=3.10
            logger.error(
                "Due to package dependency errors, 'TOBIAS ClusterMotifs' is not available for python>=3.10. Current python version is '{0}'. Please downgrade python in order to use this tool."
                .format(python_version))
            sys.exit(1)

        elif e.name == "pandas.core.indexing" and (
                version.parse(pandas_version) >= version.parse("1.3.0")):
            logger.error(
                "Package 'gimmemotifs' version < 0.17.0 requires 'pandas' version < 1.3.0. Current pandas version is {0}."
                .format(pandas_version))
            sys.exit(1)

        else:  #other import error
            logger.error(
                "Tried to import package 'gimmemotifs' but failed with error: '{0}'"
                .format(repr(e)))
            logger.error("Traceback:")
            raise  #re-raise the active ImportError so the traceback is shown

    except Exception as e:
        logger.error(
            "Tried to import package 'gimmemotifs' but failed with error: '{0}'"
            .format(repr(e)))
        logger.error(
            "Please check that 'gimmemotifs' was successfully installed.")
        sys.exit(1)

    #Check gimmemotifs version vs. metric given
    import gimmemotifs
    gimme_version = gimmemotifs.__version__
    if gimme_version == "0.17.0" and args.dist_method in ["pcc", "akl"]:
        #This is fatal (exit status 1), so report it at error level like the other fatal exits
        logger.error(
            "The dist_method given ('{0}') is invalid for gimmemotifs version 0.17.0. Please choose another --dist_method. See also: https://github.com/vanheeringen-lab/gimmemotifs/issues/243"
            .format(args.dist_method))
        sys.exit(1)

    #---------------------------------------- Reading motifs from file(s) -----------------------------------#
    logger.info("Reading input file(s)")

    motif_list = MotifList()  #list containing OneMotif objects
    motif_dict = {}  #dictionary containing separate motif lists per file

    if sys.version_info < (3, 7, 0):
        #Workaround for deepcopy on python < 3.7: register compiled regex
        #patterns as "copy by reference" in the deepcopy dispatch table
        copy._deepcopy_dispatch[type(re.compile(''))] = lambda r, _: r

    for f in args.motifs:
        logger.debug("Reading {0}".format(f))

        #Read the raw content only to detect the format; close the handle right away
        with open(f) as motif_file:
            motif_format = get_motif_format(motif_file.read())
        sub_motif_list = MotifList().from_file(f)  #MotifList object

        logger.stats("- Read {0} motifs from {1} (format: {2})".format(
            len(sub_motif_list), f, motif_format))

        motif_list.extend(sub_motif_list)
        motif_dict[f] = sub_motif_list

    #Check whether ids are unique
    #TODO

    #---------------------------------------- Motif stats ---------------------------------------------------#
    logger.info("Creating matrix statistics")

    gimmemotifs_list = [
        motif.get_gimmemotif().gimme_obj for motif in motif_list
    ]

    #Stats for all motifs
    full_motifs_out = out_prefix + "_stats_motifs.txt"
    motifs_stats = get_motif_stats(gimmemotifs_list)
    write_motif_stats(motifs_stats, full_motifs_out)

    #---------------------------------------- Motif clustering ----------------------------------------------#
    logger.info("Clustering motifs")

    clusters = motif_list.cluster(threshold=args.threshold,
                                  metric=args.dist_method,
                                  clust_method=args.clust_method)
    logger.stats("- Identified {0} clusters".format(len(clusters)))

    #Write out overview of clusters (cluster id -> list of member motif ids)
    cluster_dict = {
        cluster_id: [
            motif.get_gimmemotif().gimme_obj.id
            for motif in clusters[cluster_id]
        ]
        for cluster_id in clusters
    }
    cluster_f = out_prefix + "_" + "clusters.yml"
    logger.info("- Writing clustering to {0}".format(cluster_f))
    write_yaml(cluster_dict, cluster_f)

    # Save similarity matrix to file
    matrix_out = out_prefix + "_matrix.txt"
    logger.info("- Saving similarity matrix to the file: " + str(matrix_out))
    motif_list.similarity_matrix.to_csv(matrix_out, sep='\t')

    #Plot dendrogram
    logger.info("Plotting clustering dendrogram")
    dendrogram_f = out_prefix + "_" + "dendrogram." + args.type  #plot format pdf/png
    plot_dendrogram(motif_list.similarity_matrix.columns,
                    motif_list.linkage_mat, 12, dendrogram_f, "Clustering",
                    args.threshold, args.dpi)

    #---------------------------------------- Consensus motif -----------------------------------------------#
    logger.info("Building consensus motif for each cluster")

    consensus_motifs = MotifList()
    for cluster_id in clusters:
        members = clusters[cluster_id]  #MotifList object with create_consensus method
        consensus = members.create_consensus(metric=args.dist_method)

        #set original motif id if cluster length = 1
        consensus.id = cluster_id if len(members) > 1 else members[0].id

        consensus_motifs.append(consensus)

    #Write out consensus motif file
    out_f = out_prefix + "_consensus_motifs." + args.cons_format
    logger.info("- Writing consensus motifs to: {0}".format(out_f))
    consensus_motifs.to_file(out_f, args.cons_format)

    #Create logo plots
    logger.info(
        "- Making logo plots for consensus motifs (output folder: {0})".format(
            out_cons_img))
    for motif in consensus_motifs:
        filename = os.path.join(out_cons_img,
                                motif.id + "_consensus." + args.type)
        motif.logo_to_file(filename)

    #---------------------------------------- Plot heatmap --------------------------------------------------#

    logger.info("Plotting similarity heatmap")
    logger.info(
        "Note: Can take a while for --type=pdf. Try \"--type png\" for speed up."
    )
    args.nrc = False
    args.ncc = False
    args.zscore = "None"
    clust_linkage = motif_list.linkage_mat
    similarity_matrix = motif_list.similarity_matrix

    pdf_out = out_prefix + "_heatmap_all." + args.type
    x_label = "All motifs"
    y_label = "All motifs"
    plot_heatmap(similarity_matrix, pdf_out, clust_linkage, clust_linkage,
                 args.dpi, x_label, y_label, args.color, args.ncc, args.nrc,
                 args.zscore)

    # Plot heatmaps for each combination of motif files
    comparisons = itertools.combinations(args.motifs, 2)
    for i, (motif_file_1, motif_file_2) in enumerate(comparisons):

        pdf_out = out_prefix + "_heatmap" + str(i) + "." + args.type
        logger.info("Plotting comparison of {0} and {1} motifs to the file ".
                    format(motif_file_1, motif_file_2) + str(pdf_out))

        x_label, y_label = motif_file_1, motif_file_2

        #Create subset of matrices for row/col clustering
        motif_names_1 = [
            motif.get_gimmemotif().gimme_obj.id
            for motif in motif_dict[motif_file_1]
        ]
        motif_names_2 = [
            motif.get_gimmemotif().gimme_obj.id
            for motif in motif_dict[motif_file_2]
        ]

        m1_matrix, m2_matrix, similarity_matrix_sub = subset_matrix(
            similarity_matrix, motif_names_1, motif_names_2)

        #Linkage is only computed when both files contribute more than one motif
        can_cluster = len(motif_names_1) > 1 and len(motif_names_2) > 1
        col_linkage = linkage(
            ssd.squareform(m1_matrix)) if can_cluster else None
        row_linkage = linkage(
            ssd.squareform(m2_matrix)) if can_cluster else None

        #Plot similarity heatmap between file1 and file2
        plot_heatmap(similarity_matrix_sub, pdf_out, col_linkage, row_linkage,
                     args.dpi, x_label, y_label, args.color, args.ncc,
                     args.nrc, args.zscore)

    # ClusterMotifs finished
    logger.end()
Пример #14
0
def run_aggregate(args):
    """ Make aggregate signal plots (and optional text output) given input from args.

    Steps performed:
      1. Read sites from args.TFBS, optionally overlapped with args.regions
         (plus the non-overlapping sets if args.negate is True) and subset
         using args.whitelist / args.blacklist.
      2. Read the signal of every bigwig in args.signals for each site, after
         setting each site to a width of 2 * args.flank. Sites falling outside
         the chromosome boundaries of the bigwig are removed with a warning.
      3. Aggregate the signal per (signal, sites) combination, with optional
         outlier removal (args.remove_outliers), log-transform
         (args.log_transform), min-max normalization (args.normalize) and
         smoothing (args.smooth).
      4. Optionally write the aggregates to args.output_txt, and log footprint
         depth (FPD) and pairwise correlation statistics.
      5. Plot a grid of aggregates (with extra "Comparison" row/column when
         more than one signal / set of sites is given) and save it to
         args.output.

    Side effects on args: args.TFBS_labels, args.region_labels and
    args.signal_labels are filled from the input file names if they were None,
    and args.width is set to 2 * args.flank. The label lists themselves are
    not modified.

    Exits with sys.exit(1) if the number of labels does not match the number
    of corresponding input files, or if the TFBS labels contain duplicates.
    """

    #########################################################################################
    ############################## Setup logger/input/output ################################
    #########################################################################################

    logger = TobiasLogger("PlotAggregate", args.verbosity)
    logger.begin()

    parser = add_aggregate_arguments(argparse.ArgumentParser())
    logger.arguments_overview(parser, args)
    logger.output_files([args.output])

    #Check input parameters
    check_required(args, ["TFBS", "signals"])
    check_files([
        args.TFBS, args.signals, args.regions, args.whitelist, args.blacklist
    ],
                action="r")
    check_files([args.output], action="w")

    #### Test input ####
    if args.TFBS_labels is not None and (len(args.TFBS) != len(
            args.TFBS_labels)):
        logger.error(
            "ERROR --TFBS and --TFBS-labels have different lengths ({0} vs. {1})"
            .format(len(args.TFBS), len(args.TFBS_labels)))
        sys.exit(1)
    if args.region_labels is not None and (len(args.regions) != len(
            args.region_labels)):
        logger.error(
            "ERROR: --regions and --region-labels have different lengths ({0} vs. {1})"
            .format(len(args.regions), len(args.region_labels)))
        sys.exit(1)
    if args.signal_labels is not None and (len(args.signals) != len(
            args.signal_labels)):
        logger.error(
            "ERROR: --signals and --signal-labels have different lengths ({0} vs. {1})"
            .format(len(args.signals), len(args.signal_labels)))
        sys.exit(1)

    #### Format input ####
    #Labels default to the file names without directory and extension
    if args.TFBS_labels is None:
        args.TFBS_labels = [
            os.path.splitext(os.path.basename(f))[0] for f in args.TFBS
        ]
    if args.region_labels is None:
        args.region_labels = [
            os.path.splitext(os.path.basename(f))[0] for f in args.regions
        ]
    if args.signal_labels is None:
        args.signal_labels = [
            os.path.splitext(os.path.basename(f))[0] for f in args.signals
        ]

    #TFBS labels cannot be the same
    if len(set(args.TFBS_labels)) < len(
            args.TFBS_labels):  #this indicates duplicates
        logger.error(
            "ERROR: --TFBS-labels are not allowed to contain duplicates. Note that '--TFBS-labels' are created automatically from the '--TFBS'-files if no input was given."
            + "Please check that neither contain duplicate names")
        sys.exit(1)

    #########################################################################################
    ############################ Get input regions/signals ready ############################
    #########################################################################################

    logger.info("---- Processing input ----")
    logger.info("Reading information from .bed-files")

    #Make combinations of TFBS / regions
    region_names = []

    if len(args.regions) > 0:
        logger.info("Overlapping sites to --regions")
        regions_dict = {}

        combis = itertools.product(range(len(args.TFBS)),
                                   range(len(args.regions)))
        for (i, j) in combis:
            TFBS_f = args.TFBS[i]
            region_f = args.regions[j]

            #Make overlap
            pb_tfbs = pb.BedTool(TFBS_f)
            pb_region = pb.BedTool(region_f)
            #todo: write out lengths

            overlap = pb_tfbs.intersect(pb_region, u=True)
            #todo: length after overlap

            name = args.TFBS_labels[i] + " <OVERLAPPING> " + args.region_labels[
                j]  #name for column
            region_names.append(name)
            regions_dict[name] = RegionList().from_bed(overlap.fn)

            if args.negate == True:
                overlap_neg = pb_tfbs.intersect(pb_region, v=True)

                name = args.TFBS_labels[
                    i] + " <NOT OVERLAPPING> " + args.region_labels[j]
                region_names.append(name)
                regions_dict[name] = RegionList().from_bed(overlap_neg.fn)

    else:
        #Copy the labels so that args.TFBS_labels is never modified through region_names
        region_names = list(args.TFBS_labels)
        regions_dict = {
            label: RegionList().from_bed(bed_f)
            for label, bed_f in zip(args.TFBS_labels, args.TFBS)
        }

        for name in region_names:
            logger.stats("COUNT {0}: {1} sites".format(
                name, len(regions_dict[name])))  #length of RegionList obj

    #-------- Do overlap of regions if whitelist / blacklist -------#

    if len(args.whitelist) > 0 or len(args.blacklist) > 0:
        logger.info("Subsetting regions on whitelist/blacklist")
        for regions_id in regions_dict:
            sites = pb.BedTool(regions_dict[regions_id].as_bed(),
                               from_string=True)
            logger.stats("Found {0} sites in {1}".format(
                len(regions_dict[regions_id]), regions_id))

            #Keep only sites overlapping every whitelist file
            for whitelist_f in args.whitelist:
                whitelist = pb.BedTool(whitelist_f)
                sites = sites.intersect(whitelist, u=True)
                logger.stats("Overlapped to whitelist -> {0}".format(
                    len(sites)))

            #Remove sites overlapping any blacklist file
            for blacklist_f in args.blacklist:
                blacklist = pb.BedTool(blacklist_f)
                sites = sites.intersect(blacklist, v=True)
                logger.stats("Removed blacklist -> {0}".format(len(sites)))

            regions_dict[regions_id] = RegionList().from_bed(sites.fn)

    # Estimate motif width per --TFBS (taken from the first site; 0 if there are no sites)
    motif_widths = {}
    for regions_id in regions_dict:
        site_list = regions_dict[regions_id]
        if len(site_list) > 0:
            motif_widths[regions_id] = site_list[0].get_width()
        else:
            motif_widths[regions_id] = 0

    #########################################################################################
    ############################ Read signal for bigwig per site ############################
    #########################################################################################

    logger.info("Reading signal from bigwigs")

    args.width = args.flank * 2  #output regions will be of args.width

    signal_dict = {}
    for i, signal_f in enumerate(args.signals):

        signal_name = args.signal_labels[i]
        signal_dict[signal_name] = {}

        #Open pybw to read signal
        pybw = pyBigWig.open(signal_f)
        boundaries = pybw.chroms()  #dictionary of {chrom: length}

        logger.info("- Reading signal from {0}".format(signal_name))
        for regions_id in regions_dict:

            #Keep the regions as they were before resizing for use in warnings
            original = copy.deepcopy(regions_dict[regions_id])

            # Set width (centered on mid)
            regions_dict[regions_id].apply_method(OneRegion.set_width,
                                                  args.width)

            #Check that regions are within boundaries and remove if not
            invalid = [
                idx for idx, region in enumerate(regions_dict[regions_id])
                if region.check_boundary(boundaries, action="remove") is None
            ]
            for invalid_idx in invalid[::-1]:  #idx from higher to lower
                logger.warning(
                    "Region '{reg}' ('{orig}' before flank extension) from bed regions '{id}' is out of chromosome boundaries. This region will be excluded from output."
                    .format(reg=regions_dict[regions_id][invalid_idx].pretty(),
                            orig=original[invalid_idx].pretty(),
                            id=regions_id))
                del regions_dict[regions_id][invalid_idx]

            #Get signal from remaining regions
            for one_region in regions_dict[regions_id]:
                tup = one_region.tup()  #(chr, start, end, strand)
                if tup not in signal_dict[
                        signal_name]:  #only get signal if it was not already read previously
                    signal_dict[signal_name][tup] = one_region.get_signal(
                        pybw, logger=logger, key=signal_name)  #returns signal

        pybw.close()

    #########################################################################################
    ################################## Calculate aggregates #################################
    #########################################################################################

    signal_names = args.signal_labels

    #Calculate aggregate per signal/region comparison
    logger.info("Calculating aggregate signals")
    aggregate_dict = {
        signal_name: {region_name: []
                      for region_name in regions_dict}
        for signal_name in signal_names
    }
    for signal_name in signal_names:
        for region_name in region_names:

            signalmat = np.array([
                signal_dict[signal_name][reg.tup()]
                for reg in regions_dict[region_name]
            ])

            #Check shape of signalmat
            if signalmat.shape[0] == 0:  #no regions
                logger.warning(
                    "No regions left for '{0}'. The aggregate for this signal will be set to 0."
                    .format(signal_name))
                aggregate = np.zeros(args.width)
            else:

                #Exclude outlier rows
                max_values = np.max(signalmat, axis=1)
                upper_limit = np.percentile(
                    max_values, [100 * args.remove_outliers
                                 ])[0]  #remove-outliers is a fraction
                logical = max_values <= upper_limit
                logger.debug(
                    "{0}:{1}\tUpper limit: {2} (regions removed: {3})".format(
                        signal_name, region_name, upper_limit,
                        len(signalmat) - sum(logical)))
                signalmat = signalmat[logical]

                #Log-transform values before aggregating
                if args.log_transform:
                    signalmat_abs = np.abs(signalmat)
                    signalmat_log = np.log2(signalmat_abs + 1)
                    signalmat_log[
                        signalmat < 0] *= -1  #original negatives back to <0
                    signalmat = signalmat_log

                aggregate = np.nanmean(signalmat, axis=0)

                #normalize between 0-1
                if args.normalize:
                    aggregate = preprocessing.minmax_scale(aggregate)

                if args.smooth > 1:
                    #Pad with edge values so the rolling mean is defined at the ends, then trim the padding
                    aggregate_extend = np.pad(aggregate, args.smooth, "edge")
                    aggregate_smooth = fast_rolling_math(
                        aggregate_extend.astype('float64'), args.smooth,
                        "mean")
                    aggregate = aggregate_smooth[args.smooth:-args.smooth]

            aggregate_dict[signal_name][region_name] = aggregate
            signalmat = None  #free up space

    signal_dict = None  #free up space

    #########################################################################################
    ############################## Write aggregates to file #################################
    #########################################################################################

    if args.output_txt is not None:

        #The context manager closes the file even if writing fails
        with open(args.output_txt, "w") as f_out:
            f_out.write("### AGGREGATE\n")
            f_out.write("# Signal\tRegions\tAggregate\n")
            for signal_name in signal_names:
                for region_name in region_names:

                    agg = aggregate_dict[signal_name][region_name]
                    agg_txt = ",".join(["{:.4f}".format(val) for val in agg])

                    f_out.write("{0}\t{1}\t{2}\n".format(
                        signal_name, region_name, agg_txt))

    #########################################################################################
    ################################## Footprint measures ###################################
    #########################################################################################

    logger.comment("")
    logger.info("---- Analysis ----")

    #Measure of footprint depth in comparison to baseline
    logger.info("Calculating footprint depth measure")
    logger.info("FPD (signal,regions): footprint_width baseline middle FPD")
    for signal_name in signal_names:
        for region_name in region_names:

            agg = aggregate_dict[signal_name][region_name]
            motif_width = motif_widths[region_name]

            #Estimation of possible footprint width
            FPD_results = []
            for fp_flank in range(int(motif_width / 2),
                                  min([25, args.flank
                                       ])):  #motif width for this bed

                #Baseline level
                baseline_indices = list(range(
                    0, args.flank - fp_flank)) + list(
                        range(args.flank + fp_flank, len(agg)))
                baseline = np.mean(agg[baseline_indices])

                #Footprint level
                middle_indices = list(
                    range(args.flank - fp_flank, args.flank + fp_flank))
                middle = np.mean(agg[middle_indices])  #within the motif

                #Footprint depth
                depth = middle - baseline
                FPD_results.append([fp_flank * 2, baseline, middle, depth])

            for result in FPD_results:
                logger.stats(
                    "FPD ({0},{1}): {2} {3:.5f} {4:.5f} {5:.5f}".format(
                        signal_name, region_name, result[0], result[1],
                        result[2], result[3]))

    #Compare pairwise to calculate correlation of signals
    logger.comment("")
    logger.info("Calculating measures for comparing pairwise aggregates")
    logger.info(
        "CORRELATION (signal1,region1) VS (signal2,region2): PEARSONR\tSUM_DIFF"
    )
    plots = itertools.product(signal_names, region_names)
    combis = itertools.combinations(plots, 2)

    for ax1, ax2 in combis:
        signal1, region1 = ax1
        signal2, region2 = ax2
        agg1 = aggregate_dict[signal1][region1]
        agg2 = aggregate_dict[signal2][region2]

        pearsonr, pval = scipy.stats.pearsonr(agg1, agg2)

        diff = np.sum(np.abs(agg1 -
                             agg2))  #Sum of difference between agg1 and agg2

        logger.stats(
            "CORRELATION ({0},{1}) VS ({2},{3}): {4:.5f}\t{5:.5f}".format(
                signal1, region1, signal2, region2, pearsonr, diff))

    #########################################################################################
    ################################ Set up plotting grid ###################################
    #########################################################################################

    logger.comment("")
    logger.info("---- Plotting aggregates ----")
    logger.info("Setting up plotting grid")

    n_signals = len(signal_names)
    n_regions = len(region_names)  #regions are set of sites

    signal_compare = n_signals > 1
    region_compare = n_regions > 1

    #Define whether signal is on x/y
    #The axis names are copies, so that adding "Comparison" below does not
    #alter signal_names / region_names (or the label lists on args)
    if args.signal_on_x:

        #x-axis
        n_cols = n_signals
        col_compare = signal_compare
        col_names = list(signal_names)

        #y-axis
        n_rows = n_regions
        row_compare = region_compare
        row_names = list(region_names)
    else:
        #x-axis
        n_cols = n_regions
        col_compare = region_compare
        col_names = list(region_names)

        #y-axis
        n_rows = n_signals
        row_compare = signal_compare
        row_names = list(signal_names)

    #Compare across rows/cols?
    if row_compare:
        n_rows += 1
        row_names.append("Comparison")
    if col_compare:
        n_cols += 1
        col_names.append("Comparison")

    #Set grid
    fig, axarr = plt.subplots(n_rows,
                              n_cols,
                              figsize=(n_cols * 5, n_rows * 5),
                              constrained_layout=True)
    axarr = np.array(axarr).reshape(
        (-1,
         1)) if n_cols == 1 else axarr  #Fix indexing for one column figures
    axarr = np.array(axarr).reshape(
        (1, -1)) if n_rows == 1 else axarr  #Fix indexing for one row figures

    #X axis / Y axis labels
    #mainax = fig.add_subplot(111, frameon=False)
    #mainax.set_xlabel("X label", labelpad=30, fontsize=16)
    #mainax.set_ylabel("Y label", labelpad=30, fontsize=16)
    #mainax.xaxis.set_label_position('top')

    #Title of plot and grid
    plt.suptitle(
        " " * 7 + args.title, fontsize=16
    )  #Add a little whitespace to center the title on the plot; not the frame

    #Titles per column
    for col in range(n_cols):
        title = col_names[col].replace(" ", "\n")
        l = max([len(line) for line in title.split("\n")
                 ])  #length of longest line in title
        s = fontsize_func(l)  #decide fontsize based on length
        axarr[0, col].set_title(title, fontsize=s)

    #Titles (ylabels) per row
    for row in range(n_rows):
        label = row_names[row]
        l = max([len(line) for line in label.split("\n")])
        axarr[row, 0].set_ylabel(label, fontsize=fontsize_func(l))

    #Colors (one per row + column of the grid, including comparison row/column)
    colors = mpl.cm.brg(np.linspace(0, 1, n_rows + n_cols))

    #xvals (positions relative to center; position 0 is left out to get args.width values)
    flank = int(args.width / 2.0)
    xvals = np.arange(-flank, flank + 1)
    xvals = np.delete(xvals, flank)

    #Settings for each subplot
    for row in range(n_rows):
        for col in range(n_cols):
            axarr[row, col].set_xlim(-flank, flank)
            axarr[row, col].set_xlabel('bp from center')
            #axarr[row, col].set_ylabel('Mean aggregated signal')

    #Settings for comparison plots
    if row_compare:
        for col in range(n_cols):
            axarr[-1, col].set_facecolor("0.9")
    if col_compare:
        for row in range(n_rows):
            axarr[row, -1].set_facecolor("0.9")

    #########################################################################################
    ####################### Fill in grid with aggregate bigwig scores #######################
    #########################################################################################

    for si, signal_name in enumerate(signal_names):
        for ri, region_name in enumerate(region_names):

            logger.info("Plotting regions {0} from signal {1}".format(
                region_name, signal_name))

            row, col = (ri, si) if args.signal_on_x else (si, ri)

            #If there are any regions:
            if len(regions_dict[region_name]) > 0:

                #Signal in region
                aggregate = aggregate_dict[signal_name][region_name]
                axarr[row, col].plot(xvals,
                                     aggregate,
                                     color=colors[col + row],
                                     linewidth=1,
                                     label=signal_name)

                #Compare across rows and cols
                if col_compare:  #compare between different columns by adding one more column
                    axarr[row, -1].plot(xvals,
                                        aggregate,
                                        color=colors[row + col],
                                        linewidth=1,
                                        alpha=0.8,
                                        label=col_names[col])

                    s = min([
                        ax.title.get_fontproperties()._size
                        for ax in axarr[0, :]
                    ])  #smallest fontsize of all columns
                    axarr[row, -1].legend(loc="lower right", fontsize=s)

                if row_compare:  #compare between different rows by adding one more row

                    axarr[-1, col].plot(xvals,
                                        aggregate,
                                        color=colors[row + col],
                                        linewidth=1,
                                        alpha=0.8,
                                        label=row_names[row])

                    s = min([
                        ax.yaxis.label.get_fontproperties()._size
                        for ax in axarr[:, 0]
                    ])  #smallest fontsize of all rows
                    axarr[-1, col].legend(loc="lower right", fontsize=s)

                #Diagonal comparison
                if n_rows == n_cols and col_compare and row_compare and col == row:
                    axarr[-1, -1].plot(xvals,
                                       aggregate,
                                       color=colors[row + col],
                                       linewidth=1,
                                       alpha=0.8)

                #Add number of sites to plot
                axarr[row, col].text(0.98,
                                     0.98,
                                     str(len(regions_dict[region_name])),
                                     transform=axarr[row, col].transAxes,
                                     fontsize=12,
                                     va="top",
                                     ha="right")

                #Motif boundaries (can only be compared across sites)
                if args.plot_boundaries:

                    #Get motif width for this list of TFBS
                    width = motif_widths[region_name]

                    mstart = -np.floor(width / 2.0)
                    mend = np.ceil(
                        width / 2.0) - 1  #as it spans the "0" nucleotide
                    axarr[row, col].axvline(mstart,
                                            color="grey",
                                            linestyle="dashed",
                                            linewidth=1)
                    axarr[row, col].axvline(mend,
                                            color="grey",
                                            linestyle="dashed",
                                            linewidth=1)

    #------------- Finishing up plots ---------------#

    logger.info("Adjusting final details")

    #remove lower-right corner if not applicable
    if n_rows != n_cols and n_rows > 1 and n_cols > 1:
        axarr[-1, -1].axis('off')
        axarr[-1, -1] = None  #marks the cell as unused for the steps below

    #Check whether share_y is set
    if args.share_y == "none":
        pass

    #Comparable rows (rowcompare = same across all)
    elif (args.share_y == "signals"
          and args.signal_on_x == False) or (args.share_y == "sites"
                                             and args.signal_on_x == True):
        for col in range(n_cols):
            lims = np.array(
                [ax.get_ylim() for ax in axarr[:, col] if ax is not None])
            ymin, ymax = np.min(lims), np.max(lims)

            #Set limit across rows for this col
            for row in range(n_rows):
                if axarr[row, col] is not None:
                    axarr[row, col].set_ylim(ymin, ymax)

    #Comparable columns (colcompare = same across all)
    elif (args.share_y == "sites"
          and args.signal_on_x == False) or (args.share_y == "signals"
                                             and args.signal_on_x == True):
        for row in range(n_rows):
            lims = np.array(
                [ax.get_ylim() for ax in axarr[row, :] if ax is not None])
            ymin, ymax = np.min(lims), np.max(lims)

            #Set limit across cols for this row
            for col in range(n_cols):
                if axarr[row, col] is not None:
                    axarr[row, col].set_ylim(ymin, ymax)

    #Comparable on both rows/columns
    elif args.share_y == "both":
        global_ymin, global_ymax = np.inf, -np.inf
        for row in range(n_rows):
            for col in range(n_cols):
                if axarr[row, col] is not None:
                    local_ymin, local_ymax = axarr[row, col].get_ylim()
                    global_ymin = local_ymin if local_ymin < global_ymin else global_ymin
                    global_ymax = local_ymax if local_ymax > global_ymax else global_ymax

        for row in range(n_rows):
            for col in range(n_cols):
                if axarr[row, col] is not None:
                    axarr[row, col].set_ylim(global_ymin, global_ymax)

    #Force plots to be square (skipping the lower-right cell if it was removed above)
    for row in range(n_rows):
        for col in range(n_cols):
            if axarr[row, col] is not None:
                forceSquare(axarr[row, col])

    plt.savefig(args.output, bbox_inches='tight')
    plt.close()

    logger.end()
Пример #15
0
def run_scorebigwig(args):
	""" Calculate scores across regions of a signal bigwig and write them to an output bigwig.

	Reads chromosome names/lengths from args.signal, prepares the regions to
	score (args.regions after extension by args.extend, merging and cutting at
	chromosome boundaries), splits them into chunks via regions.chunks(args.split)
	and scores the chunks in a pool of worker processes using calculate_scores.
	A separate writer process running bigwig_writer receives the results
	through a managed queue and writes args.output.

	Side effects on args: sets args.log_q, args.region_flank, args.cores and
	args.writer_qs, which are passed on to the worker/writer processes.
	"""

	check_required(args, ["signal", "output", "regions"])
	check_files([args.signal, args.regions], "r")
	check_files([args.output], "w")

	#---------------------------------------------------------------------------------------#
	# Create logger and write info to log
	#---------------------------------------------------------------------------------------#

	logger = TobiasLogger("ScoreBigwig", args.verbosity)
	logger.begin()

	parser = add_scorebigwig_arguments(argparse.ArgumentParser())
	logger.arguments_overview(parser, args)
	logger.output_files([args.output])

	logger.debug("Setting up listener for log")
	logger.start_logger_queue()
	args.log_q = logger.queue

	#---------------------------------------------------------------------------------------#
	#----------------------- I/O - get regions/bigwig ready --------------------------------#
	#---------------------------------------------------------------------------------------#

	logger.info("Processing input files")

	logger.info("- Opening input cutsite bigwig")
	pybw_signal = pyBigWig.open(args.signal)
	pybw_header = pybw_signal.chroms()
	chrom_info = {chrom:int(pybw_header[chrom]) for chrom in pybw_header}
	pybw_signal.close()	#only the header is needed in this function
	logger.debug("Chromosome lengths from input bigwig: {0}".format(chrom_info))

	#Decide regions 
	logger.info("- Getting output regions ready")
	if args.regions:
		regions = RegionList().from_bed(args.regions)

		#Check whether regions are available in input bigwig
		not_in_bigwig = list(set(regions.get_chroms()) - set(chrom_info.keys()))
		if len(not_in_bigwig) > 0:
			logger.warning("Contigs {0} were found in input --regions, but were not found in input --signal. These regions cannot be scored and will therefore be excluded from output.".format(not_in_bigwig))
			regions = regions.remove_chroms(not_in_bigwig)
		
		regions.apply_method(OneRegion.extend_reg, args.extend)
		regions.merge()
		regions.apply_method(OneRegion.check_boundary, chrom_info, "cut")

	else:
		regions = RegionList().from_chrom_lengths(chrom_info)

	#Set flank to enable scoring in ends of regions
	if args.score == "sum":
		args.region_flank = int(args.window/2.0)
	elif args.score == "footprint" or args.score == "FOS":
		args.region_flank = int(args.flank_max)
	else:
		args.region_flank = 0

	#Go through each region: extend by the flank, cut at chromosome boundaries and
	#remove the flank again, so that flanks can be added later without leaving the chromosome
	for region in regions:
		region.extend_reg(args.region_flank)
		region = region.check_boundary(chrom_info, "cut")
		region.extend_reg(-args.region_flank)

	#Information for output bigwig
	reference_chroms = sorted(list(chrom_info.keys()))
	header = [(chrom, chrom_info[chrom]) for chrom in reference_chroms]
	regions.loc_sort(reference_chroms)

	#---------------------------------------------------------------------------------------#
	#------------------------ Calculating footprints and writing out -----------------------#
	#---------------------------------------------------------------------------------------#

	logger.info("Calculating footprints in regions...")
	regions_chunks = regions.chunks(args.split)

	#Setup pools
	args.cores = check_cores(args.cores, logger)
	writer_cores = 1	
	worker_cores = max(1, args.cores - writer_cores)
	logger.debug("Worker cores: {0}".format(worker_cores))
	logger.debug("Writer cores: {0}".format(writer_cores))

	worker_pool = mp.Pool(processes=worker_cores)
	writer_pool = mp.Pool(processes=writer_cores)
	manager = mp.Manager()

	#Start bigwig file writers
	writer_q = manager.Queue()
	writer_pool.apply_async(bigwig_writer, args=(writer_q, {"scores":args.output}, header, regions, args))
	writer_pool.close() #no more jobs applied to writer_pool
	writer_qs = {"scores": writer_q}

	args.writer_qs = writer_qs

	#Start calculating scores in the worker pool
	task_list = [worker_pool.apply_async(calculate_scores, args=[chunk, args]) for chunk in regions_chunks]
	worker_pool.close() #no more jobs applied to worker_pool
	monitor_progress(task_list, logger)
	for task in task_list:
		task.get()	#re-raises any exception which occurred in a worker

	#Stop all queues for writing
	logger.debug("Stop all queues by inserting None")
	for q in writer_qs.values():
		q.put((None, None, None))

	#Done computing; wait for writer and workers to exit
	writer_pool.join() 
	worker_pool.join()
	
	logger.stop_logger_queue()

	#Finished scoring
	logger.end()
Пример #16
0
def bias_correction(regions_list, params, bias_obj):
    """ Corrects bias in cutsites (from bamfile) using estimated bias

    For each region in regions_list, four per-strand tracks are built
    ("uncorrected", "bias", "expected" and "corrected"), trimmed back to the
    region size and put on the writer queues found in params.qs. K-mers around
    positions with signal are additionally collected before/after correction.

    Parameters:
        regions_list: iterable of region objects. The code uses .extend_reg(),
            .get_length(), .chrom, .start and .end on each of them.
        params: run settings. Attributes read here: verbosity, log_q, bam,
            genome, k_flank, read_shift, window, qs and split_strands.
        bias_obj: estimated bias. The code uses bias_obj.bias[strand].score_sequence()
            and bias_obj.correction_factor.

    Returns:
        [pre_bias, post_bias] - two dicts keyed by "forward"/"reverse", each
        holding the object returned by SequenceMatrix.create(L, "PWM"), filled
        via add_sequence() with uncorrected and corrected values respectively.

    Side effects:
        Puts (key, reg_key, signal) tuples on params.qs[key], where key is
        "<track>:<strand>" and reg_key is (chrom, start, end) of the output region.
        extend_reg(f_extend) is called on every region and not reverted here.
    """

    logger = TobiasLogger("", params.verbosity, params.log_q)

    bam_f = params.bam
    fasta_f = params.genome
    k_flank = params.k_flank
    read_shift = params.read_shift
    L = 2 * k_flank + 1  #length of the k-mer around a cutsite
    w = params.window
    f = int(w / 2.0)  #half of the window
    qs = params.qs

    #Total extension per side: room for k-mers (k_flank) plus half a window (f)
    f_extend = k_flank + f

    strands = ["forward", "reverse"]
    pre_bias = {strand: SequenceMatrix.create(L, "PWM") for strand in strands}
    post_bias = {strand: SequenceMatrix.create(L, "PWM") for strand in strands}

    #Open bamfile and fasta
    bam_obj = pysam.AlignmentFile(bam_f, "rb")
    fasta_obj = pysam.FastaFile(fasta_f)

    out_signals = {}

    #Go through each region
    for region_obj in regions_list:

        #reg_key subtracts f_extend again, so extend_reg presumably modifies region_obj in place - TODO confirm
        region_obj.extend_reg(f_extend)
        reg_len = region_obj.get_length()  #length including flanking
        reg_key = (region_obj.chrom, region_obj.start + f_extend,
                   region_obj.end - f_extend)  #output region
        out_signals[reg_key] = {
            "uncorrected": {},
            "bias": {},
            "expected": {},
            "corrected": {}
        }

        ################################
        ####### Uncorrected reads ######
        ################################

        #Get cutsite positions for each read
        read_lst = ReadList().from_bam(bam_obj, region_obj)
        for read in read_lst:
            read.get_cutsite(read_shift)
        logger.spam("Read {0} reads from region {1}".format(
            len(read_lst), region_obj))

        #Exclude reads with cutsites outside region
        #NOTE(review): both borders are strict here, whereas count_reads uses "<= region.end" - confirm which is intended
        read_lst = ReadList([
            read for read in read_lst if read.cutsite > region_obj.start
            and read.cutsite < region_obj.end
        ])
        for_lst, rev_lst = read_lst.split_strands()
        read_lst_strand = {"forward": for_lst, "reverse": rev_lst}

        #Per-strand signal across the (extended) region, rounded to 5 decimals
        for strand in strands:
            out_signals[reg_key]["uncorrected"][strand] = read_lst_strand[
                strand].signal(region_obj)
            out_signals[reg_key]["uncorrected"][strand] = np.round(
                out_signals[reg_key]["uncorrected"][strand], 5)

        ################################
        ###### Estimation of bias ######
        ################################

        #Get sequence in this region
        sequence_obj = GenomicSequence(region_obj).from_fasta(fasta_obj)

        #Score sequence using forward/reverse motifs
        for strand in strands:
            if strand == "forward":
                seq = sequence_obj.sequence
                bias = bias_obj.bias[strand].score_sequence(seq)
            elif strand == "reverse":
                seq = sequence_obj.revcomp
                bias = bias_obj.bias[strand].score_sequence(seq)[::-1]  #3'-5'

            out_signals[reg_key]["bias"][strand] = np.nan_to_num(
                bias)  #convert any nans to 0

        #################################
        ###### Correction of reads ######
        #################################

        #Windows of params.window bp, starting every `step` bp between the k_flank borders
        reg_end = reg_len - k_flank
        step = 10
        overlaps = int(params.window / step)  #number of rows in bias_predictions
        window_starts = list(range(k_flank, reg_end - params.window, step))
        window_ends = list(range(k_flank + params.window, reg_end, step))
        #Last window is stretched to the end of the region
        #NOTE(review): raises IndexError if the ranges above are empty (region too short for one window)
        window_ends[-1] = reg_len
        windows = list(zip(window_starts, window_ends))

        for strand in strands:

            ########### Estimate bias threshold ###########
            bias_predictions = np.zeros((overlaps, reg_len))
            row = 0

            for window in windows:

                signal_w = out_signals[reg_key]["uncorrected"][strand][
                    window[0]:window[1]]
                bias_w = out_signals[reg_key]["bias"][strand][
                    window[0]:window[1]]

                signalmax = np.max(signal_w)
                #biasmin = np.min(bias_w)
                #biasmax = np.max(bias_w)

                if signalmax > 0:
                    try:
                        #Fit relu(bias) to the signal and use the fitted curve as prediction
                        popt, pcov = curve_fit(relu, bias_w, signal_w)
                        bias_predict = relu(bias_w, *popt)

                    except (OptimizeWarning, RuntimeError):
                        #Fallback: bias relative to the lowest bias seen at a position with signal, floored at 0
                        cut_positions = np.logical_not(np.isclose(signal_w, 0))
                        bias_min = np.min(bias_w[cut_positions])
                        bias_predict = bias_w - bias_min
                        bias_predict[bias_predict < 0] = 0

                    #Scale prediction to a maximum of 1
                    if np.max(bias_predict) > 0:
                        bias_predict = bias_predict / np.max(bias_predict)
                else:
                    bias_predict = np.zeros(window[1] - window[0])

                bias_predictions[row, window[0]:window[1]] = bias_predict
                #NOTE(review): row increases up to overlaps - 1 and then stays there (it does not
                #wrap to 0), so all later windows are written to the last row - confirm this is intended
                row += 1 if row < overlaps - 1 else 0

            #Mean across rows gives one bias value per position
            bias_prediction = np.mean(bias_predictions, axis=0)
            bias = bias_prediction

            ######## Calculate expected signal ######
            signal_sum = fast_rolling_math(
                out_signals[reg_key]["uncorrected"][strand], w, "sum")
            signal_sum[np.isnan(signal_sum)] = 0  #f-width ends of region

            bias_sum = fast_rolling_math(bias, w, "sum")  #ends of arr are nan
            nulls = np.logical_or(np.isclose(bias_sum, 0), np.isnan(bias_sum))
            bias_sum[nulls] = 1  # N-regions will give stretches of 0-bias
            bias_probas = bias / bias_sum
            bias_probas[nulls] = 0  #nan to 0

            out_signals[reg_key]["expected"][strand] = signal_sum * bias_probas

            ######## Correct signal ########
            #Both tracks are multiplied by correction_factor in place before subtraction
            out_signals[reg_key]["uncorrected"][
                strand] *= bias_obj.correction_factor
            out_signals[reg_key]["expected"][
                strand] *= bias_obj.correction_factor
            out_signals[reg_key]["corrected"][
                strand] = out_signals[reg_key]["uncorrected"][
                    strand] - out_signals[reg_key]["expected"][strand]

            ######## Rescale signal to fit uncorrected sum ########
            uncorrected_sum = fast_rolling_math(
                out_signals[reg_key]["uncorrected"][strand], w, "sum")
            uncorrected_sum[np.isnan(uncorrected_sum)] = 0
            corrected_sum = fast_rolling_math(
                np.abs(out_signals[reg_key]["corrected"][strand]), w,
                "sum")  #negative values count as positive
            corrected_sum[np.isnan(corrected_sum)] = 0

            #Positive signal left after correction
            corrected_pos = np.copy(out_signals[reg_key]["corrected"][strand])
            corrected_pos[corrected_pos < 0] = 0
            corrected_pos_sum = fast_rolling_math(corrected_pos, w, "sum")
            corrected_pos_sum[np.isnan(corrected_pos_sum)] = 0
            corrected_neg_sum = corrected_sum - corrected_pos_sum  #absolute sum of the negative values

            #The corrected sum is less than the signal sum, so scale up positive cuts
            zero_sum = corrected_pos_sum == 0
            corrected_pos_sum[zero_sum] = np.nan  #allow for zero division
            scale_factor = (uncorrected_sum -
                            corrected_neg_sum) / corrected_pos_sum
            scale_factor[
                zero_sum] = 1  #Scale factor is 1 (which will be multiplied to the 0 values)
            scale_factor[scale_factor < 1] = 1  #Only scale up if needed
            pos_bool = out_signals[reg_key]["corrected"][strand] > 0
            out_signals[reg_key]["corrected"][strand][
                pos_bool] *= scale_factor[pos_bool]

        #######################################
        ########   Verify correction   ########
        #######################################

        #Verify correction across all reads
        for strand in strands:
            for idx in range(k_flank, reg_len - k_flank - 1):
                #The condition additionally skips idx == k_flank; its upper bound is already implied by the range
                if idx > k_flank and idx < reg_len - k_flank:

                    orig = out_signals[reg_key]["uncorrected"][strand][idx]
                    correct = out_signals[reg_key]["corrected"][strand][idx]

                    if orig != 0 or correct != 0:  #if both are 0, don't add to pre/post bias
                        if strand == "forward":
                            kmer = sequence_obj.sequence[idx - k_flank:idx +
                                                         k_flank + 1]
                        else:
                            #Same position, indexed from the other end of the reverse-complement sequence
                            kmer = sequence_obj.revcomp[reg_len - idx -
                                                        k_flank - 1:reg_len -
                                                        idx + k_flank]

                        #Save kmer for bias correction verification
                        pre_bias[strand].add_sequence(kmer, orig)
                        post_bias[strand].add_sequence(kmer, correct)

        #######################################
        ########    Write to queue    #########
        #######################################

        #Set size back to original
        for track in out_signals[reg_key]:
            for strand in out_signals[reg_key][track]:
                out_signals[reg_key][track][strand] = out_signals[reg_key][
                    track][strand][f_extend:-f_extend]

        #Calculate "both" if split_strands == False
        if params.split_strands == False:
            for track in out_signals[reg_key]:
                out_signals[reg_key][track]["both"] = out_signals[reg_key][
                    track]["forward"] + out_signals[reg_key][track]["reverse"]

        #Send to queue
        strands_to_write = ["forward", "reverse"
                            ] if params.split_strands == True else ["both"]
        for track in out_signals[reg_key]:

            #Send to writer per strand
            for strand in strands_to_write:
                key = "{0}:{1}".format(track, strand)
                logger.spam(
                    "Sending {0} signal from region {1} to writer queue".
                    format(key, reg_key))
                qs[key].put(
                    (key, reg_key, out_signals[reg_key][track][strand]))

        #Sent to qs - delete from this process
        out_signals[reg_key] = None

    bam_obj.close()
    fasta_obj.close()

    gc.collect()

    return ([pre_bias, post_bias])
Пример #17
0
def bias_estimation(regions_list, params):
    """ Collects k-mer counts at cutsites and at shifted background positions within regions

    Parameters:
        regions_list: iterable of region objects (uses .extend_reg(), .start, .end)
        params: run settings. Attributes read here: bam, genome, k_flank, bg_shift,
            read_shift, verbosity, log_q and score_mat.

    Returns:
        The AtacBias object created here. Per strand, bias[strand].add_sequence() has
        been called with the k-mer at each accepted cutsite and
        bias[strand].add_background() with the k-mer at the cutsite shifted by -bg_shift.
        no_reads is increased by the (capped) number of cuts per position.
    """

    #Run settings
    k_flank = params.k_flank
    bg_shift = params.bg_shift
    read_shift = params.read_shift
    kmer_length = 2 * k_flank + 1

    #All logger calls are sent to log_q
    logger = TobiasLogger("", params.verbosity, params.log_q)

    #Input files; chromosome boundaries are taken from the bam header
    bam_obj = pysam.AlignmentFile(params.bam, "rb")
    fasta_obj = pysam.FastaFile(params.genome)
    chrom_lengths = dict(zip(bam_obj.references, bam_obj.lengths))

    bias_obj = AtacBias(kmer_length, params.score_mat)

    for region in regions_list:

        #Reads of this region with their cutsites set
        read_lst = ReadList().from_bam(bam_obj, region)
        for read in read_lst:
            read.get_cutsite(read_shift)

        if len(read_lst) == 0:
            continue  #no reads, nothing to add for this region

        #Sequence of the region extended by k_flank + bg_shift to allow full kmers
        #NOTE(review): extend_reg is called on `region` itself; if it modifies the region in
        #place, the border check further down uses the extended coordinates - confirm
        extended_region = region.extend_reg(k_flank + bg_shift)
        extended_region.check_boundary(chrom_lengths, "cut")
        sequence_obj = GenomicSequence(extended_region).from_fasta(fasta_obj)

        forward_reads, reverse_reads = read_lst.split_strands()
        logger.spam(
            "Region: {0}. Forward reads: {1}. Reverse reads: {2}".format(
                region, len(forward_reads), len(reverse_reads)))

        for strand, strand_reads in (("forward", forward_reads),
                                     ("reverse", reverse_reads)):

            #Group reads by cutsite position
            reads_by_cutsite = {}
            for read in strand_reads:
                cigar = read.cigartuples
                if cigar is None:
                    continue

                #Cigar element at the end of the read where the cutsite is located
                end_op = cigar[-1] if read.is_reverse else cigar[0]
                if end_op[0] != 0:
                    continue  #only include non-clipped reads

                #The element must be longer than the flank plus the largest read shift
                if end_op[1] > params.k_flank + max(np.abs(params.read_shift)):
                    reads_by_cutsite.setdefault(read.cutsite, []).append(read)

            #One kmer (and one background kmer) per cutsite position
            for cutsite, reads_here in reads_by_cutsite.items():
                if not (region.start < cutsite < region.end):
                    continue  #only cutsites within borders

                representative = reads_here[0]  #first read establishes the kmer
                no_cut = min(len(reads_here), 10)  #cap limits the influence of outliers

                representative.get_kmer(sequence_obj, k_flank)
                bias_obj.bias[strand].add_sequence(representative.kmer, no_cut)

                #Background kmer is taken upstream of the read so it is not within the fragment
                representative.shift_cutsite(-bg_shift)
                representative.get_kmer(sequence_obj, k_flank)
                bias_obj.bias[strand].add_background(representative.kmer, no_cut)

                bias_obj.no_reads += no_cut

    bam_obj.close()
    fasta_obj.close()

    return bias_obj  #object containing information collected on bias
Пример #18
0
def calculate_scores(regions, args):
	""" Calculates scores from the bigwig signal within regions and sends them to the writer queue

	Parameters:
		regions: iterable of region objects (uses .extend_reg(), .get_signal(), .chrom, .start, .end).
			extend_reg(args.region_flank) is called on every region and not reverted here.
		args: run settings. Attributes read here: verbosity, log_q, signal, region_flank, absolute,
			min_limit, max_limit, score, window, flank_min, flank_max, fp_min, fp_max, smooth
			and writer_qs.

	Returns:
		1 when all regions were processed. The scores themselves are not returned; for each
		region a tuple ("scores", (chrom, start, end), scores) is put on args.writer_qs["scores"].

	Exits via sys.exit() if args.score is not one of "sum", "mean", "footprint", "FOS" or "none".
	"""

	logger = TobiasLogger("", args.verbosity, args.log_q)

	pybw_signal = pyBigWig.open(args.signal) 	#cutsites signal

	#Set flank to enable scoring in ends of regions
	flank = args.region_flank

	#The bigwig handle is closed in 'finally' so it is not leaked when a region fails or sys.exit is called
	try:

		#Go through each region
		for region in regions:

			logger.debug("Calculating scores for region: {0}".format(region))

			#Extend region with necessary flank
			region.extend_reg(flank)
			reg_key = (region.chrom, region.start+flank, region.end-flank)	#output region

			#Get bigwig signal in region
			signal = region.get_signal(pybw_signal)
			signal = np.nan_to_num(signal).astype("float64")

			#-------- Prepare signal for score calculation -------#
			if args.absolute:
				signal = np.abs(signal)

			#Clip signal to [min_limit, max_limit] where the limits are given
			if args.min_limit is not None:
				signal[signal < args.min_limit] = args.min_limit
			if args.max_limit is not None:
				signal[signal > args.max_limit] = args.max_limit

			#------------------ Calculate scores ----------------#
			if args.score == "sum":
				scores = fast_rolling_math(signal, args.window, "sum")

			elif args.score == "mean":
				scores = fast_rolling_math(signal, args.window, "mean")

			elif args.score == "footprint":
				scores = tobias_footprint_array(signal, args.flank_min, args.flank_max, args.fp_min, args.fp_max)		#numpy array

			elif args.score == "FOS":
				scores = FOS_score(signal, args.flank_min, args.flank_max, args.fp_min, args.fp_max)
				scores = -scores	#sign of FOS is flipped before writing

			elif args.score == "none":
				scores = signal

			else:
				sys.exit("Scoring {0} not found".format(args.score))

			#----------------- Post-process scores --------------#

			#Smooth signal with args.smooth bp
			if args.smooth > 1:
				scores = fast_rolling_math(scores, args.smooth, "mean")

			#Remove ends to prevent overlap with other regions
			if flank > 0:
				scores = scores[flank:-flank]

			args.writer_qs["scores"].put(("scores", reg_key, scores))

	finally:
		pybw_signal.close()

	return(1)