def main():

	ref_inds = indexRef(REF)
	refList  = [n[0] for n in ref_inds]

	# how many times do we observe each trinucleotide in the reference (and input bed region, if present)?
	TRINUC_REF_COUNT = {}
	TRINUC_BED_COUNT = {}
	printBedWarning  = True
	# [(trinuc_a, trinuc_b)] = # of times we observed a mutation from trinuc_a into trinuc_b
	TRINUC_TRANSITION_COUNT = {}
	# total count of SNPs
	SNP_COUNT = 0
	# overall SNP transition probabilities
	SNP_TRANSITION_COUNT = {}
	# total count of indels, indexed by length
	INDEL_COUNT = {}
	# tabulate how much non-N reference sequence we've eaten through
	TOTAL_REFLEN = 0
	# detect variants that occur in a significant percentage of the input samples (pos,ref,alt,pop_fraction)
	COMMON_VARIANTS = []
	# tabulate how many unique donors we've encountered (this is useful for identifying common variants)
	TOTAL_DONORS = {}
	# identify regions that have significantly higher local mutation rates than the average
	HIGH_MUT_REGIONS = []

	# load and process variants in each reference sequence individually, for memory reasons...
	for refName in refList:

		if (refName not in REF_WHITELIST) and (not NO_WHITELIST):
			print refName,'is not in our whitelist, skipping...'
			continue

		print 'reading reference "'+refName+'"...'
		refSequence = getChrFromFasta(REF,ref_inds,refName).upper()
		TOTAL_REFLEN += len(refSequence) - refSequence.count('N')

		# list to be used for counting variants that occur multiple times in file (i.e. in multiple samples)
		VDAT_COMMON = []


		""" ##########################################################################
		###						COUNT TRINUCLEOTIDES IN REF						   ###
		########################################################################## """


		if MYBED != None:
			if printBedWarning:
				print "since you're using a bed input, we have to count trinucs in bed region even if you specified a trinuc count file for the reference..."
				printBedWarning = False
			if refName in MYBED[0]:
				refKey = refName
			elif ('chr' in refName) and (refName not in MYBED[0]) and (refName[3:] in MYBED[0]):
				refKey = refName[3:]
			elif ('chr' not in refName) and (refName not in MYBED[0]) and ('chr'+refName in MYBED[0]):
				refKey = 'chr'+refName
			if refKey in MYBED[0]:
				subRegions = [(MYBED[0][refKey][n],MYBED[0][refKey][n+1]) for n in xrange(0,len(MYBED[0][refKey]),2)]
				for sr in subRegions:
					for i in xrange(sr[0],sr[1]+1-2):
						trinuc = refSequence[i:i+3]
						if not trinuc in VALID_TRINUC:
							continue	# skip if trinuc contains invalid characters, or not in specified bed region
						if trinuc not in TRINUC_BED_COUNT:
							TRINUC_BED_COUNT[trinuc] = 0
						TRINUC_BED_COUNT[trinuc] += 1

		if not os.path.isfile(REF+'.trinucCounts'):
			print 'counting trinucleotides in reference...'
			for i in xrange(len(refSequence)-2):
				if i%1000000 == 0 and i > 0:
					print i,'/',len(refSequence)
					#break
				trinuc = refSequence[i:i+3]
				if not trinuc in VALID_TRINUC:
					continue	# skip if trinuc contains invalid characters
				if trinuc not in TRINUC_REF_COUNT:
					TRINUC_REF_COUNT[trinuc] = 0
				TRINUC_REF_COUNT[trinuc] += 1
		else:
			print 'skipping trinuc counts (for whole reference) because we found a file...'


		""" ##########################################################################
		###							READ INPUT VARIANTS							   ###
		########################################################################## """


		print 'reading input variants...'
		f = open(TSV,'r')
		isFirst = True
		for line in f:

			if IS_VCF and line[0] == '#':
				continue
			if isFirst:
				if IS_VCF:
					# hard-code index values based on expected columns in vcf
					(c1,c2,c3,m1,m2,m3) = (0,1,1,3,3,4)
				else:
					# determine columns of fields we're interested in
					splt = line.strip().split('\t')
					(c1,c2,c3) = (splt.index('chromosome'),splt.index('chromosome_start'),splt.index('chromosome_end'))
					(m1,m2,m3) = (splt.index('reference_genome_allele'),splt.index('mutated_from_allele'),splt.index('mutated_to_allele'))
					(d_id) = (splt.index('icgc_donor_id'))
				isFirst = False
				continue

			splt = line.strip().split('\t')
			# we have -1 because tsv/vcf coords are 1-based, and our reference string index is 0-based
			[chrName,chrStart,chrEnd] = [splt[c1],int(splt[c2])-1,int(splt[c3])-1]
			[allele_ref,allele_normal,allele_tumor] = [splt[m1].upper(),splt[m2].upper(),splt[m3].upper()]
			if IS_VCF:
				if len(allele_ref) != len(allele_tumor):
					# indels in tsv don't include the preserved first nucleotide, so lets trim the vcf alleles
					[allele_ref,allele_normal,allele_tumor] = [allele_ref[1:],allele_normal[1:],allele_tumor[1:]]
				if not allele_ref: allele_ref = '-'
				if not allele_normal: allele_normal = '-'
				if not allele_tumor: allele_tumor = '-'
				# if alternate alleles are present, lets just ignore this variant. I may come back and improve this later
				if ',' in allele_tumor:
					continue
				vcf_info = ';'+splt[7]+';'
			else:
				[donor_id] = [splt[d_id]]

			# if we encounter a multi-np (i.e. 3 nucl --> 3 different nucl), let's skip it for now...
			if ('-' not in allele_normal and '-' not in allele_tumor) and (len(allele_normal) > 1 or len(allele_tumor) > 1):
				print 'skipping a complex variant...'
				continue

			# to deal with '1' vs 'chr1' references, manually change names. this is hacky and bad.
			if 'chr' not in chrName:
				chrName = 'chr'+chrName
			if 'chr' not in refName:
				refName = 'chr'+refName
			# skip irrelevant variants
			if chrName != refName:
				continue

			# if variant is outside the regions we're interested in (if specified), skip it...
			if MYBED != None:
				refKey = refName
				if not refKey in MYBED[0] and refKey[3:] in MYBED[0]:	# account for 1 vs chr1, again...
					refKey = refKey[3:]
				if refKey not in MYBED[0]:
					inBed = False
				else:
					inBed = isInBed(MYBED[0][refKey],chrStart)
				if inBed != MYBED[1]:
					continue

			# we want only snps
			# so, no '-' characters allowed, and chrStart must be same as chrEnd
			if '-' not in allele_normal and '-' not in allele_tumor and chrStart == chrEnd:
				trinuc_ref = refSequence[chrStart-1:chrStart+2]
				if not trinuc_ref in VALID_TRINUC:
					continue	# skip ref trinuc with invalid characters
				# only consider positions where ref allele in tsv matches the nucleotide in our reference
				if allele_ref == trinuc_ref[1]:
					trinuc_normal    = refSequence[chrStart-1] + allele_normal + refSequence[chrStart+1]
					trinuc_tumor     = refSequence[chrStart-1] + allele_tumor + refSequence[chrStart+1]
					if not trinuc_normal in VALID_TRINUC or not trinuc_tumor in VALID_TRINUC:
						continue	# skip if mutation contains invalid char
					key = (trinuc_normal,trinuc_tumor)
					if key not in TRINUC_TRANSITION_COUNT:
						TRINUC_TRANSITION_COUNT[key] = 0
					TRINUC_TRANSITION_COUNT[key] += 1
					SNP_COUNT += 1
					key2 = (allele_normal,allele_tumor)
					if key2 not in SNP_TRANSITION_COUNT:
						SNP_TRANSITION_COUNT[key2] = 0
					SNP_TRANSITION_COUNT[key2] += 1

					if IS_VCF:
						myPopFreq = VCF_DEFAULT_POP_FREQ
						if ';CAF=' in vcf_info:
							cafStr = re.findall(r";CAF=.*?(?=;)",vcf_info)[0]
							if ',' in cafStr:
								myPopFreq = float(cafStr[5:].split(',')[1])
						VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor,myPopFreq))
					else:
						VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor))
						TOTAL_DONORS[donor_id] = True
				else:
					print '\nError: ref allele in variant call does not match reference.\n'
					exit(1)

			# now let's look for indels...
			if '-' in allele_normal: len_normal = 0
			else: len_normal = len(allele_normal)
			if '-' in allele_tumor: len_tumor = 0
			else: len_tumor = len(allele_tumor)
			if len_normal != len_tumor:
				indel_len = len_tumor - len_normal
				if indel_len not in INDEL_COUNT:
					INDEL_COUNT[indel_len] = 0
				INDEL_COUNT[indel_len] += 1

				if IS_VCF:
					myPopFreq = VCF_DEFAULT_POP_FREQ
					if ';CAF=' in vcf_info:
						cafStr = re.findall(r";CAF=.*?(?=;)",vcf_info)[0]
						if ',' in cafStr:
							myPopFreq = float(cafStr[5:].split(',')[1])
					VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor,myPopFreq))
				else:
					VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor))
					TOTAL_DONORS[donor_id] = True
		f.close()

		# if we didn't find anything, skip ahead along to the next reference sequence
		if not len(VDAT_COMMON):
			print 'Found no variants for this reference, moving along...'
			continue

		#
		# identify common mutations
		#
		percentile_var = 95
		if IS_VCF:
			minVal = np.percentile([n[4] for n in VDAT_COMMON],percentile_var)
			for k in sorted(VDAT_COMMON):
				if k[4] >= minVal:
					COMMON_VARIANTS.append((refName,k[0],k[1],k[3],k[4]))
			VDAT_COMMON = {(n[0],n[1],n[2],n[3]):n[4] for n in VDAT_COMMON}
		else:
			N_DONORS = len(TOTAL_DONORS)
			VDAT_COMMON = list_2_countDict(VDAT_COMMON)
			minVal = int(np.percentile(VDAT_COMMON.values(),percentile_var))
			for k in sorted(VDAT_COMMON.keys()):
				if VDAT_COMMON[k] >= minVal:
					COMMON_VARIANTS.append((refName,k[0],k[1],k[3],VDAT_COMMON[k]/float(N_DONORS)))

		#
		# identify areas that have contained significantly higher random mutation rates
		#
		dist_thresh      = 2000
		percentile_clust = 97
		qptn             = 1000
		# identify regions with disproportionately more variants in them
		VARIANT_POS = sorted([n[0] for n in VDAT_COMMON.keys()])
		clustered_pos = clusterList(VARIANT_POS,dist_thresh)
		byLen  = [(len(clustered_pos[i]),min(clustered_pos[i]),max(clustered_pos[i]),i) for i in xrange(len(clustered_pos))]
		#byLen  = sorted(byLen,reverse=True)
		#minLen = int(np.percentile([n[0] for n in byLen],percentile_clust))
		#byLen  = [n for n in byLen if n[0] >= minLen]
		candidate_regions = []
		for n in byLen:
			bi = int((n[1]-dist_thresh)/float(qptn))*qptn
			bf = int((n[2]+dist_thresh)/float(qptn))*qptn
			candidate_regions.append((n[0]/float(bf-bi),max([0,bi]),min([len(refSequence),bf])))
		minVal = np.percentile([n[0] for n in candidate_regions],percentile_clust)
		for n in candidate_regions:
			if n[0] >= minVal:
				HIGH_MUT_REGIONS.append((refName,n[1],n[2],n[0]))
		# collapse overlapping regions
		for i in xrange(len(HIGH_MUT_REGIONS)-1,0,-1):
			if HIGH_MUT_REGIONS[i-1][2] >= HIGH_MUT_REGIONS[i][1] and HIGH_MUT_REGIONS[i-1][0] == HIGH_MUT_REGIONS[i][0]:
				avgMutRate = 0.5*HIGH_MUT_REGIONS[i-1][3]+0.5*HIGH_MUT_REGIONS[i][3]	# not accurate, but I'm lazy
				HIGH_MUT_REGIONS[i-1] = (HIGH_MUT_REGIONS[i-1][0], HIGH_MUT_REGIONS[i-1][1], HIGH_MUT_REGIONS[i][2], avgMutRate)
				del HIGH_MUT_REGIONS[i]

	#
	# if we didn't count ref trinucs because we found file, read in ref counts from file now
	#
	if os.path.isfile(REF+'.trinucCounts'):
		print 'reading pre-computed trinuc counts...'
		f = open(REF+'.trinucCounts','r')
		for line in f:
			splt = line.strip().split('\t')
			TRINUC_REF_COUNT[splt[0]] = int(splt[1])
		f.close()
	# otherwise, save trinuc counts to file, if desired
	elif SAVE_TRINUC:
		if MYBED != None:
			print 'unable to save trinuc counts to file because using input bed region...'
		else:
			print 'saving trinuc counts to file...'
			f = open(REF+'.trinucCounts','w')
			for trinuc in sorted(TRINUC_REF_COUNT.keys()):
				f.write(trinuc+'\t'+str(TRINUC_REF_COUNT[trinuc])+'\n')
			f.close()

	#
	# if using an input bed region, make necessary adjustments to trinuc ref counts based on the bed region trinuc counts
	#
	if MYBED != None:
		if MYBED[1] == True:	# we are restricting our attention to bed regions, so ONLY use bed region trinuc counts
			TRINUC_REF_COUNT = TRINUC_BED_COUNT
		else:					# we are only looking outside bed regions, so subtract bed region trinucs from entire reference trinucs
			for k in TRINUC_REF_COUNT.keys():
				if k in TRINUC_BED_COUNT:
					TRINUC_REF_COUNT[k] -= TRINUC_BED_COUNT[k]

	# if for some reason we didn't find any valid input variants, exit gracefully...
	totalVar = SNP_COUNT + sum(INDEL_COUNT.values())
	if totalVar == 0:
		print '\nError: No valid variants were found, model could not be created. (Are you using the correct reference?)\n'
		exit(1)

	""" ##########################################################################
	###							COMPUTE PROBABILITIES						   ###
	########################################################################## """


	#for k in sorted(TRINUC_REF_COUNT.keys()):
	#		print k, TRINUC_REF_COUNT[k]
	#
	#for k in sorted(TRINUC_TRANSITION_COUNT.keys()):
	#	print k, TRINUC_TRANSITION_COUNT[k]

	# frequency that each trinuc mutated into anything else
	TRINUC_MUT_PROB = {}
	# frequency that a trinuc mutates into another trinuc, given that it mutated
	TRINUC_TRANS_PROBS = {}
	# frequency of snp transitions, given a snp occurs.
	SNP_TRANS_FREQ = {}

	for trinuc in sorted(TRINUC_REF_COUNT.keys()):
		myCount = 0
		for k in sorted(TRINUC_TRANSITION_COUNT.keys()):
			if k[0] == trinuc:
				myCount += TRINUC_TRANSITION_COUNT[k]
		TRINUC_MUT_PROB[trinuc] = myCount / float(TRINUC_REF_COUNT[trinuc])
		for k in sorted(TRINUC_TRANSITION_COUNT.keys()):
			if k[0] == trinuc:
				TRINUC_TRANS_PROBS[k] = TRINUC_TRANSITION_COUNT[k] / float(myCount)

	for n1 in VALID_NUCL:
		rollingTot = sum([SNP_TRANSITION_COUNT[(n1,n2)] for n2 in VALID_NUCL if (n1,n2) in SNP_TRANSITION_COUNT])
		for n2 in VALID_NUCL:
			key2 = (n1,n2)
			if key2 in SNP_TRANSITION_COUNT:
				SNP_TRANS_FREQ[key2] = SNP_TRANSITION_COUNT[key2] / float(rollingTot)

	# compute average snp and indel frequencies
	SNP_FREQ       = SNP_COUNT/float(totalVar)
	AVG_INDEL_FREQ = 1.-SNP_FREQ
	INDEL_FREQ     = {k:(INDEL_COUNT[k]/float(totalVar))/AVG_INDEL_FREQ for k in INDEL_COUNT.keys()}
	if MYBED != None:
		if MYBED[1] == True:
			AVG_MUT_RATE = totalVar/float(getTrackLen(MYBED[0]))
		else:
			AVG_MUT_RATE = totalVar/float(TOTAL_REFLEN - getTrackLen(MYBED[0]))
	else:
		AVG_MUT_RATE = totalVar/float(TOTAL_REFLEN)

	#
	#	if values weren't found in data, appropriately append null entries
	#
	printTrinucWarning = False
	for trinuc in VALID_TRINUC:
		trinuc_mut = [trinuc[0]+n+trinuc[2] for n in VALID_NUCL if n != trinuc[1]]
		if trinuc not in TRINUC_MUT_PROB:
			TRINUC_MUT_PROB[trinuc] = 0.
			printTrinucWarning = True
		for trinuc2 in trinuc_mut:
			if (trinuc,trinuc2) not in TRINUC_TRANS_PROBS:
				TRINUC_TRANS_PROBS[(trinuc,trinuc2)] = 0.
				printTrinucWarning = True
	if printTrinucWarning:
		print 'Warning: Some trinucleotides transitions were not encountered in the input dataset, probabilities of 0.0 have been assigned to these events.'

	#
	#	print some stuff
	#
	for k in sorted(TRINUC_MUT_PROB.keys()):
		print 'p('+k+' mutates) =',TRINUC_MUT_PROB[k]

	for k in sorted(TRINUC_TRANS_PROBS.keys()):
		print 'p('+k[0]+' --> '+k[1]+' | '+k[0]+' mutates) =',TRINUC_TRANS_PROBS[k]

	for k in sorted(INDEL_FREQ.keys()):
		if k > 0:
			print 'p(ins length = '+str(abs(k))+' | indel occurs) =',INDEL_FREQ[k]
		else:
			print 'p(del length = '+str(abs(k))+' | indel occurs) =',INDEL_FREQ[k]

	for k in sorted(SNP_TRANS_FREQ.keys()):
		print 'p('+k[0]+' --> '+k[1]+' | SNP occurs) =',SNP_TRANS_FREQ[k]

	#for n in COMMON_VARIANTS:
	#	print n

	#for n in HIGH_MUT_REGIONS:
	#	print n

	print 'p(snp)   =',SNP_FREQ
	print 'p(indel) =',AVG_INDEL_FREQ
	print 'overall average mut rate:',AVG_MUT_RATE
	print 'total variants processed:',totalVar

	#
	# save variables to file
	#
	if SKIP_COMMON:
		OUT_DICT = {'AVG_MUT_RATE':AVG_MUT_RATE,
		            'SNP_FREQ':SNP_FREQ,
		            'SNP_TRANS_FREQ':SNP_TRANS_FREQ,
		            'INDEL_FREQ':INDEL_FREQ,
		            'TRINUC_MUT_PROB':TRINUC_MUT_PROB,
		            'TRINUC_TRANS_PROBS':TRINUC_TRANS_PROBS}
	else:
		OUT_DICT = {'AVG_MUT_RATE':AVG_MUT_RATE,
		            'SNP_FREQ':SNP_FREQ,
		            'SNP_TRANS_FREQ':SNP_TRANS_FREQ,
		            'INDEL_FREQ':INDEL_FREQ,
		            'TRINUC_MUT_PROB':TRINUC_MUT_PROB,
		            'TRINUC_TRANS_PROBS':TRINUC_TRANS_PROBS,
		            'COMMON_VARIANTS':COMMON_VARIANTS,
		            'HIGH_MUT_REGIONS':HIGH_MUT_REGIONS}
	pickle.dump( OUT_DICT, open( OUT_PICKLE, "wb" ) )
Example #2
0
def main():

	# index reference
	refIndex = indexRef(REFERENCE)
	if PAIRED_END:
		N_HANDLING = ('random',FRAGMENT_SIZE)
	else:
		N_HANDLING = ('ignore',READLEN)
	indices_by_refName = {refIndex[n][0]:n for n in xrange(len(refIndex))}

	# parse input variants, if present
	inputVariants = []
	if INPUT_VCF != None:
		if CANCER:
			(sampNames, inputVariants) = parseVCF(INPUT_VCF,tumorNormal=True,ploidy=PLOIDS)
			tumorInd  = sampNames.index('TUMOR')
			normalInd = sampNames.index('NORMAL')
		else:
			(sampNames, inputVariants) = parseVCF(INPUT_VCF,ploidy=PLOIDS)
		for k in sorted(inputVariants.keys()):
			inputVariants[k].sort()

	# parse input targeted regions, if present
	refList      = [n[0] for n in refIndex]
	inputRegions = {}
	if INPUT_BED != None:
		f = open(INPUT_BED,'r')
		for line in f:
			[myChr,pos1,pos2] = line.strip().split('\t')[:3]
			if myChr not in inputRegions:
				inputRegions[myChr] = [-1]
			inputRegions[myChr].extend([int(pos1),int(pos2)])
		f.close()
		# some validation
		nInBedOnly = 0
		nInRefOnly = 0
		for k in refList:
			if k not in inputRegions:
				nInRefOnly += 1
		for k in inputRegions.keys():
			if not k in refList:
				nInBedOnly += 1
				del inputRegions[k]
		if nInRefOnly > 0:
			print 'Warning: Reference contains sequences not found in targeted regions BED file.'
		if nInBedOnly > 0:
			print 'Warning: Targeted regions BED file contains sequence names not found in reference (regions ignored).'
	# parse discard bed similarly
	discardRegions = {}
	if DISCARD_BED != None:
		f = open(DISCARD_BED,'r')
		for line in f:
			[myChr,pos1,pos2] = line.strip().split('\t')[:3]
			if myChr not in discardRegions:
				discardRegions[myChr] = [-1]
			discardRegions[myChr].extend([int(pos1),int(pos2)])
		f.close()

	# parse input mutation rate rescaling regions, if present
	mutRateRegions = {}
	mutRateValues  = {}
	if MUT_BED != None:
		with open(MUT_BED,'r') as f:
			for line in f:
				[myChr,pos1,pos2,metaData] = line.strip().split('\t')[:4]
				mutStr = re.findall(r"MUT_RATE=.*?(?=;)",metaData+';')
				(pos1,pos2) = (int(pos1),int(pos2))
				if len(mutStr) and (pos2-pos1) > 1:
					# mutRate = #_mutations / length_of_region, let's bound it by a reasonable amount
					mutRate = max([0.0,min([float(mutStr[0][9:]),0.3])])
					if myChr not in mutRateRegions:
						mutRateRegions[myChr] = [-1]
						mutRateValues[myChr]  = [0.0]
					mutRateRegions[myChr].extend([pos1,pos2])
					mutRateValues.extend([mutRate*(pos2-pos1)]*2)

	# initialize output files (part I)
	bamHeader = None
	if SAVE_BAM:
		bamHeader = [copy.deepcopy(refIndex)]
	vcfHeader = None
	if SAVE_VCF:
		vcfHeader = [REFERENCE]
	
	# If processing jobs in parallel, precompute the independent regions that can be process separately
	if NJOBS > 1:
		parallelRegionList  = getAllRefRegions(REFERENCE,refIndex,N_HANDLING,saveOutput=SAVE_NON_N)
		(myRefs, myRegions) = partitionRefRegions(parallelRegionList,refIndex,MYJOB,NJOBS)
		if not len(myRegions):
			print 'This job id has no regions to process, exiting...'
			exit(1)
		for i in xrange(len(refIndex)-1,-1,-1):	# delete reference not used in our job
			if not refIndex[i][0] in myRefs:
				del refIndex[i]
		# if value of NJOBS is too high, let's change it to the maximum possible, to avoid output filename confusion
		corrected_nJobs = min([NJOBS,sum([len(n) for n in parallelRegionList.values()])])
	else:
		corrected_nJobs = 1

	# initialize output files (part II)
	if CANCER:
		OFW = OutputFileWriter(OUT_PREFIX+'_normal',paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,noFASTQ=NO_FASTQ,FASTA_instead=FASTA_INSTEAD)
		OFW_CANCER = OutputFileWriter(OUT_PREFIX+'_tumor',paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,jobTuple=(MYJOB,corrected_nJobs),noFASTQ=NO_FASTQ,FASTA_instead=FASTA_INSTEAD)
	else:
		OFW = OutputFileWriter(OUT_PREFIX,paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,jobTuple=(MYJOB,corrected_nJobs),noFASTQ=NO_FASTQ,FASTA_instead=FASTA_INSTEAD)
	OUT_PREFIX_NAME = OUT_PREFIX.split('/')[-1]


	"""************************************************
	****        LET'S GET THIS PARTY STARTED...
	************************************************"""


	readNameCount = 1	# keep track of the number of reads we've sampled, for read-names
	unmapped_records = []

	for RI in xrange(len(refIndex)):

		# read in reference sequence and notate blocks of Ns
		(refSequence,N_regions) = readRef(REFERENCE,refIndex[RI],N_HANDLING)

		# if we're processing jobs in parallel only take the regions relevant for the current job
		if NJOBS > 1:
			for i in xrange(len(N_regions['non_N'])-1,-1,-1):
				if not (refIndex[RI][0],N_regions['non_N'][i][0],N_regions['non_N'][i][1]) in myRegions:
					del N_regions['non_N'][i]

		# count total bp we'll be spanning so we can get an idea of how far along we are (for printing progress indicators)
		total_bp_span   = sum([n[1]-n[0] for n in N_regions['non_N']])
		currentProgress = 0
		currentPercent  = 0
		havePrinted100  = False

		# prune invalid input variants, e.g variants that:
		#		- try to delete or alter any N characters
		#		- don't match the reference base at their specified position
		#		- any alt allele contains anything other than allowed characters
		validVariants = []
		nSkipped = [0,0,0]
		if refIndex[RI][0] in inputVariants:
			for n in inputVariants[refIndex[RI][0]]:
				span = (n[0],n[0]+len(n[1]))
				rseq = str(refSequence[span[0]-1:span[1]-1])	# -1 because going from VCF coords to array coords
				anyBadChr = any((nn not in ALLOWED_NUCL) for nn in [item for sublist in n[2] for item in sublist])
				if rseq != n[1]:
					nSkipped[0] += 1
					continue
				elif 'N' in rseq:
					nSkipped[1] += 1
					continue
				elif anyBadChr:
					nSkipped[2] += 1
					continue
				#if bisect.bisect(N_regions['big'],span[0])%2 or bisect.bisect(N_regions['big'],span[1])%2:
				#	continue
				validVariants.append(n)
			print 'found',len(validVariants),'valid variants for '+refIndex[RI][0]+' in input VCF...'
			if any(nSkipped):
				print sum(nSkipped),'variants skipped...'
				print ' - ['+str(nSkipped[0])+'] ref allele does not match reference'
				print ' - ['+str(nSkipped[1])+'] attempting to insert into N-region'
				print ' - ['+str(nSkipped[2])+'] alt allele contains non-ACGT characters'


		# add large random structural variants
		#
		#	TBD!!!


		# determine sampling windows based on read length, large N regions, and structural mutations.
		# in order to obtain uniform coverage, windows should overlap by:
		#		- READLEN, if single-end reads
		#		- FRAGMENT_SIZE (mean), if paired-end reads
		# ploidy is fixed per large sampling window,
		# coverage distributions due to GC% and targeted regions are specified within these windows
		samplingWindows  = []
		ALL_VARIANTS_OUT = {}
		sequences        = None
		if PAIRED_END:
			targSize = WINDOW_TARGET_SCALE*FRAGMENT_SIZE
			overlap  = FRAGMENT_SIZE
			overlap_minWindowSize = max(FRAGLEN_DISTRIBUTION.values) + 10
		else:
			targSize = WINDOW_TARGET_SCALE*READLEN
			overlap  = READLEN
			overlap_minWindowSize = READLEN + 10

		print '--------------------------------'
		if ONLY_VCF:
			print 'generating vcf...'
		else:
			print 'sampling reads...'
		tt = time.time()

		for i in xrange(len(N_regions['non_N'])):
			(pi,pf) = N_regions['non_N'][i]
			nTargWindows = max([1,(pf-pi)/targSize])
			bpd = int((pf-pi)/float(nTargWindows))
			#bpd += GC_WINDOW_SIZE - bpd%GC_WINDOW_SIZE

			#print len(refSequence), (pi,pf), nTargWindows
			#print structuralVars

			# if for some reason our region is too small to process, skip it! (sorry)
			if nTargWindows == 1 and (pf-pi) < overlap_minWindowSize:
				#print 'Does this ever happen?'
				continue

			start = pi
			end   = min([start+bpd,pf])
			#print '------------------RAWR:', (pi,pf), nTargWindows, bpd
			varsFromPrevOverlap = []
			varsCancerFromPrevOverlap = []
			vindFromPrev = 0
			isLastTime = False
			havePrinted100 = False

			while True:
				
				# which inserted variants are in this window?
				varsInWindow = []
				updated = False
				for j in xrange(vindFromPrev,len(validVariants)):
					vPos = validVariants[j][0]
					if vPos > start and vPos < end:	# update: changed >= to >, so variant cannot be inserted in first position
						varsInWindow.append(tuple([vPos-1]+list(validVariants[j][1:])))	# vcf --> array coords
					if vPos >= end-overlap-1 and updated == False:
						updated = True
						vindFromPrev = j
					if vPos >= end:
						break

				# determine which structural variants will affect our sampling window positions
				structuralVars = []
				for n in varsInWindow:
					bufferNeeded = max([max([abs(len(n[1])-len(alt_allele)),1]) for alt_allele in n[2]]) # change: added abs() so that insertions are also buffered.
					structuralVars.append((n[0]-1,bufferNeeded))	# -1 because going from VCF coords to array coords

				# adjust end-position of window based on inserted structural mutations
				buffer_added = 0
				keepGoing = True
				while keepGoing:
					keepGoing = False
					for n in structuralVars:
						# adding "overlap" here to prevent SVs from being introduced in overlap regions
						# (which can cause problems if random mutations from the previous window land on top of them)
						delta = (end-1) - (n[0] + n[1]) - 2 - overlap
						if delta < 0:
							#print 'DELTA:', delta, 'END:', end, '-->',
							buffer_added = -delta
							end += buffer_added
							####print end
							keepGoing = True
							break
				next_start = end-overlap
				next_end   = min([next_start+bpd,pf])
				if next_end-next_start < bpd:
					end = next_end
					isLastTime = True

				# print progress indicator
				#print 'PROCESSING WINDOW:',(start,end), [buffer_added], 'next:', (next_start,next_end), 'isLastTime:', isLastTime
				currentProgress += end-start
				newPercent = int((currentProgress*100)/float(total_bp_span))
				if newPercent > currentPercent:
					if newPercent <= 99 or (newPercent == 100 and not havePrinted100):
						sys.stdout.write(str(newPercent)+'% ')
						sys.stdout.flush()
					currentPercent = newPercent
					if currentPercent == 100:
						havePrinted100 = True

				skip_this_window = False

				# compute coverage modifiers
				coverage_avg = None
				coverage_dat = [GC_WINDOW_SIZE,GC_SCALE_VAL,[]]
				target_hits  = 0
				if INPUT_BED == None:
					coverage_dat[2] = [1.0]*(end-start)
				else:
					if refIndex[RI][0] not in inputRegions:
						coverage_dat[2] = [OFFTARGET_SCALAR]*(end-start)
					else:
						for j in xrange(start,end):
							if not(bisect.bisect(inputRegions[refIndex[RI][0]],j)%2):
								coverage_dat[2].append(1.0)
								target_hits += 1
							else:
								coverage_dat[2].append(OFFTARGET_SCALAR)

				# offtarget and we're not interested?
				if OFFTARGET_DISCARD and target_hits <= READLEN:
					coverage_avg = 0.0
					skip_this_window = True

				#print len(coverage_dat[2]), sum(coverage_dat[2])
				if sum(coverage_dat[2]) < LOW_COV_THRESH:
					coverage_avg = 0.0
					skip_this_window = True

				# check for small window sizes
				if (end-start) < overlap_minWindowSize:
					skip_this_window = True

				if skip_this_window:
					# skip window, save cpu time
					start = next_start
					end   = next_end
					if isLastTime:
						break
					if end >= pf:
						isLastTime = True
					varsFromPrevOverlap = []
					continue

				# construct sequence data that we will sample reads from
				if sequences == None:
					sequences = SequenceContainer(start,refSequence[start:end],PLOIDS,overlap,READLEN,[MUT_MODEL]*PLOIDS,MUT_RATE,onlyVCF=ONLY_VCF)
				else:
					sequences.update(start,refSequence[start:end],PLOIDS,overlap,READLEN,[MUT_MODEL]*PLOIDS,MUT_RATE)

				# insert variants
				sequences.insert_mutations(varsFromPrevOverlap + varsInWindow)
				all_inserted_variants = sequences.random_mutations()
				#print all_inserted_variants

				# init coverage
				if sum(coverage_dat[2]) >= LOW_COV_THRESH:
					if PAIRED_END:
						coverage_avg = sequences.init_coverage(tuple(coverage_dat),fragDist=FRAGLEN_DISTRIBUTION)
					else:
						coverage_avg = sequences.init_coverage(tuple(coverage_dat))

				# unused cancer stuff
				if CANCER:
					tumor_sequences = SequenceContainer(start,refSequence[start:end],PLOIDS,overlap,READLEN,[CANCER_MODEL]*PLOIDS,MUT_RATE,coverage_dat)
					tumor_sequences.insert_mutations(varsCancerFromPrevOverlap + all_inserted_variants)
					all_cancer_variants = tumor_sequences.random_mutations()

				# which variants do we need to keep for next time (because of window overlap)?
				varsFromPrevOverlap       = []
				varsCancerFromPrevOverlap = []
				for n in all_inserted_variants:
					if n[0] >= end-overlap-1:
						varsFromPrevOverlap.append(n)
				if CANCER:
					for n in all_cancer_variants:
						if n[0] >= end-overlap-1:
							varsCancerFromPrevOverlap.append(n)

				# if we're only producing VCF, no need to go through the hassle of generating reads
				if ONLY_VCF:
					pass
				else:
					windowSpan = end-start
					if PAIRED_END:
						if FORCE_COVERAGE:
							readsToSample = int((windowSpan*float(COVERAGE))/(2*READLEN))+1
						else:
							readsToSample = int((windowSpan*float(COVERAGE)*coverage_avg)/(2*READLEN))+1
					else:
						if FORCE_COVERAGE:
							readsToSample = int((windowSpan*float(COVERAGE))/READLEN)+1
						else:
							readsToSample = int((windowSpan*float(COVERAGE)*coverage_avg)/READLEN)+1

					# if coverage is so low such that no reads are to be sampled, skip region
					#      (i.e., remove buffer of +1 reads we add to every window)
					if readsToSample == 1 and sum(coverage_dat[2]) < LOW_COV_THRESH:
						readsToSample = 0

					# sample reads
					ASDF2_TT = time.time()
					for i in xrange(readsToSample):

						isUnmapped = []
						if PAIRED_END:
							myFraglen = FRAGLEN_DISTRIBUTION.sample()
							myReadData = sequences.sample_read(SE_CLASS,myFraglen)
							if myReadData == None:	# skip if we failed to find a valid position to sample read
								continue
							if myReadData[0][0] == None:
								isUnmapped.append(True)
							else:
								isUnmapped.append(False)
								myReadData[0][0] += start	# adjust mapping position based on window start
							if myReadData[1][0] == None:
								isUnmapped.append(True)
							else:
								isUnmapped.append(False)
								myReadData[1][0] += start
						else:
							myReadData = sequences.sample_read(SE_CLASS)
							if myReadData == None:	# skip if we failed to find a valid position to sample read
								continue
							if myReadData[0][0] == None:	# unmapped read (lives in large insertion)
								isUnmapped = [True]
							else:
								isUnmapped = [False]
								myReadData[0][0] += start	# adjust mapping position based on window start

						# are we discarding offtargets?
						outside_boundaries = []
						if OFFTARGET_DISCARD and INPUT_BED != None:
							outside_boundaries += [bisect.bisect(inputRegions[refIndex[RI][0]],n[0])%2 for n in myReadData]
							outside_boundaries += [bisect.bisect(inputRegions[refIndex[RI][0]],n[0]+len(n[2]))%2 for n in myReadData]
						if DISCARD_BED != None:
							outside_boundaries += [bisect.bisect(discardRegions[refIndex[RI][0]],n[0])%2 for n in myReadData]
							outside_boundaries += [bisect.bisect(discardRegions[refIndex[RI][0]],n[0]+len(n[2]))%2 for n in myReadData]
						if len(outside_boundaries) and any(outside_boundaries):
							continue

						if NJOBS > 1:
							myReadName = OUT_PREFIX_NAME+'-j'+str(MYJOB)+'-'+refIndex[RI][0]+'-r'+str(readNameCount)
						else:
							myReadName = OUT_PREFIX_NAME+'-'+refIndex[RI][0]+'-'+str(readNameCount)
						readNameCount += len(myReadData)

						# if desired, replace all low-quality bases with Ns
						if N_MAX_QUAL > -1:
							for j in xrange(len(myReadData)):
								myReadString = [n for n in myReadData[j][2]]
								for k in xrange(len(myReadData[j][3])):
									adjusted_qual = ord(myReadData[j][3][k])-SE_CLASS.offQ
									if adjusted_qual <= N_MAX_QUAL:
										myReadString[k] = 'N'
								myReadData[j][2] = ''.join(myReadString)

						# flip a coin, are we forward or reverse strand?
						isForward = (random.random() < 0.5)

						# if read (or read + mate for PE) are unmapped, put them at end of bam file
						if all(isUnmapped):
							if PAIRED_END:
								if isForward:
									flag1 = sam_flag(['paired','unmapped','mate_unmapped','first','mate_reverse'])
									flag2 = sam_flag(['paired','unmapped','mate_unmapped','second','reverse'])
								else:
									flag1 = sam_flag(['paired','unmapped','mate_unmapped','second','mate_reverse'])
									flag2 = sam_flag(['paired','unmapped','mate_unmapped','first','reverse'])
								unmapped_records.append((myReadName+'/1',myReadData[0],flag1))
								unmapped_records.append((myReadName+'/2',myReadData[1],flag2))
							else:
								flag1 = sam_flag(['unmapped'])
								unmapped_records.append((myReadName+'/1',myReadData[0],flag1))

						myRefIndex = indices_by_refName[refIndex[RI][0]]
						
						#
						# write SE output
						#
						if len(myReadData) == 1:
							if NO_FASTQ != True:
								if isForward:
									OFW.writeFASTQRecord(myReadName,myReadData[0][2],myReadData[0][3])
								else:
									OFW.writeFASTQRecord(myReadName,RC(myReadData[0][2]),myReadData[0][3][::-1])
							if SAVE_BAM:
								if isUnmapped[0] == False:
									if isForward:
										flag1 = 0
										OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=flag1)
									else:
										flag1 = sam_flag(['reverse'])
										OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=flag1)
						#
						# write PE output
						#
						elif len(myReadData) == 2:
							if NO_FASTQ != True:
								OFW.writeFASTQRecord(myReadName,myReadData[0][2],myReadData[0][3],read2=myReadData[1][2],qual2=myReadData[1][3],orientation=isForward)
							if SAVE_BAM:
								if isUnmapped[0] == False and isUnmapped[1] == False:
									if isForward:
										flag1 = sam_flag(['paired','proper','first','mate_reverse'])
										flag2 = sam_flag(['paired','proper','second','reverse'])
									else:
										flag1 = sam_flag(['paired','proper','second','mate_reverse'])
										flag2 = sam_flag(['paired','proper','first','reverse'])
									OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=flag1, matePos=myReadData[1][0])
									OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[1][0], myReadData[1][1], myReadData[1][2], myReadData[1][3], samFlag=flag2, matePos=myReadData[0][0])
								elif isUnmapped[0] == False and isUnmapped[1] == True:
									if isForward:
										flag1 = sam_flag(['paired','first', 'mate_unmapped', 'mate_reverse'])
										flag2 = sam_flag(['paired','second', 'unmapped', 'reverse'])
									else:
										flag1 = sam_flag(['paired','second', 'mate_unmapped', 'mate_reverse'])
										flag2 = sam_flag(['paired','first', 'unmapped', 'reverse'])
									OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=flag1, matePos=myReadData[0][0])
									OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[0][0], myReadData[1][1], myReadData[1][2], myReadData[1][3], samFlag=flag2, matePos=myReadData[0][0], alnMapQual=0)
								elif isUnmapped[0] == True and isUnmapped[1] == False:
									if isForward:
										flag1 = sam_flag(['paired','first', 'unmapped', 'mate_reverse'])
										flag2 = sam_flag(['paired','second', 'mate_unmapped', 'reverse'])
									else:
										flag1 = sam_flag(['paired','second', 'unmapped', 'mate_reverse'])
										flag2 = sam_flag(['paired','first', 'mate_unmapped', 'reverse'])
									OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[1][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=flag1, matePos=myReadData[1][0], alnMapQual=0)
									OFW.writeBAMRecord(myRefIndex, myReadName, myReadData[1][0], myReadData[1][1], myReadData[1][2], myReadData[1][3], samFlag=flag2, matePos=myReadData[1][0])
						else:
							print '\nError: Unexpected number of reads generated...\n'
							exit(1)
					#print 'READS:',time.time()-ASDF2_TT

					if not isLastTime:
						OFW.flushBuffers(bamMax=next_start)
					else:
						OFW.flushBuffers(bamMax=end+1)

				# tally up all the variants that got successfully introduced
				for n in all_inserted_variants:
					ALL_VARIANTS_OUT[n] = True

				# prepare indices of next window
				start = next_start
				end   = next_end
				if isLastTime:
					break
				if end >= pf:
					isLastTime = True

		if currentPercent != 100 and not havePrinted100:
			print '100%'
		else:
			print ''
		if ONLY_VCF:
			print 'VCF generation completed in',
		else:
			print 'Read sampling completed in',
		print int(time.time()-tt),'(sec)'

		# write all output variants for this reference
		if SAVE_VCF:
			print 'Writing output VCF...'
			for k in sorted(ALL_VARIANTS_OUT.keys()):
				currentRef = refIndex[RI][0]
				myID       = '.'
				myQual     = '.'
				myFilt     = 'PASS'
				# k[0] + 1 because we're going back to 1-based vcf coords
				OFW.writeVCFRecord(currentRef, str(int(k[0])+1), myID, k[1], k[2], myQual, myFilt, k[4])

		#break

	# write unmapped reads to bam file
	if SAVE_BAM and len(unmapped_records):
		print 'writing unmapped reads to bam file...'
		for umr in unmapped_records:
			if PAIRED_END:
				OFW.writeBAMRecord(-1, umr[0], 0, umr[1][1], umr[1][2], umr[1][3], samFlag=umr[2], matePos=0, alnMapQual=0)
			else:
				OFW.writeBAMRecord(-1, umr[0], 0, umr[1][1], umr[1][2], umr[1][3], samFlag=umr[2], alnMapQual=0)

	# close output files
	OFW.closeFiles()
	if CANCER:
		OFW_CANCER.closeFiles()
Example #3
0
def main():

	ref_inds = indexRef(REF)
	refList  = [n[0] for n in ref_inds]

	# how many times do we observe each trinucleotide in the reference (and input bed region, if present)?
	TRINUC_REF_COUNT = {}
	TRINUC_BED_COUNT = {}
	printBedWarning  = True
	# [(trinuc_a, trinuc_b)] = # of times we observed a mutation from trinuc_a into trinuc_b
	TRINUC_TRANSITION_COUNT = {}
	# total count of SNPs
	SNP_COUNT = 0
	# overall SNP transition probabilities
	SNP_TRANSITION_COUNT = {}
	# total count of indels, indexed by length
	INDEL_COUNT = {}
	# tabulate how much non-N reference sequence we've eaten through
	TOTAL_REFLEN = 0
	# detect variants that occur in a significant percentage of the input samples (pos,ref,alt,pop_fraction)
	COMMON_VARIANTS = []
	# tabulate how many unique donors we've encountered (this is useful for identifying common variants)
	TOTAL_DONORS = {}
	# identify regions that have significantly higher local mutation rates than the average
	HIGH_MUT_REGIONS = []

	# load and process variants in each reference sequence individually, for memory reasons...
	for refName in refList:

		if refName not in REF_WHITELIST:
			print refName,'is not in our whitelist, skipping...'
			continue

		print 'reading reference "'+refName+'"...'
		refSequence = getChrFromFasta(REF,ref_inds,refName).upper()
		TOTAL_REFLEN += len(refSequence) - refSequence.count('N')

		# list to be used for counting variants that occur multiple times in file (i.e. in multiple samples)
		VDAT_COMMON = []


		""" ##########################################################################
		###						COUNT TRINUCLEOTIDES IN REF						   ###
		########################################################################## """


		if MYBED != None:
			if printBedWarning:
				print "since you're using a bed input, we have to count trinucs in bed region even if you specified a trinuc count file for the reference..."
				printBedWarning = False
			if refName in MYBED[0]:
				refKey = refName
			elif ('chr' in refName) and (refName not in MYBED[0]) and (refName[3:] in MYBED[0]):
				refKey = refName[3:]
			elif ('chr' not in refName) and (refName not in MYBED[0]) and ('chr'+refName in MYBED[0]):
				refKey = 'chr'+refName
			if refKey in MYBED[0]:
				subRegions = [(MYBED[0][refKey][n],MYBED[0][refKey][n+1]) for n in xrange(0,len(MYBED[0][refKey]),2)]
				for sr in subRegions:
					for i in xrange(sr[0],sr[1]+1-2):
						trinuc = refSequence[i:i+3]
						if not trinuc in VALID_TRINUC:
							continue	# skip if trinuc contains invalid characters, or not in specified bed region
						if trinuc not in TRINUC_BED_COUNT:
							TRINUC_BED_COUNT[trinuc] = 0
						TRINUC_BED_COUNT[trinuc] += 1

		if not os.path.isfile(REF+'.trinucCounts'):
			print 'counting trinucleotides in reference...'
			for i in xrange(len(refSequence)-2):
				if i%1000000 == 0 and i > 0:
					print i,'/',len(refSequence)
					#break
				trinuc = refSequence[i:i+3]
				if not trinuc in VALID_TRINUC:
					continue	# skip if trinuc contains invalid characters
				if trinuc not in TRINUC_REF_COUNT:
					TRINUC_REF_COUNT[trinuc] = 0
				TRINUC_REF_COUNT[trinuc] += 1
		else:
			print 'skipping trinuc counts (for whole reference) because we found a file...'


		""" ##########################################################################
		###							READ INPUT VARIANTS							   ###
		########################################################################## """


		print 'reading input variants...'
		f = open(TSV,'r')
		isFirst = True
		for line in f:

			if IS_VCF and line[0] == '#':
				continue
			if isFirst:
				if IS_VCF:
					# hard-code index values based on expected columns in vcf
					(c1,c2,c3,m1,m2,m3) = (0,1,1,3,3,4)
				else:
					# determine columns of fields we're interested in
					splt = line.strip().split('\t')
					(c1,c2,c3) = (splt.index('chromosome'),splt.index('chromosome_start'),splt.index('chromosome_end'))
					(m1,m2,m3) = (splt.index('reference_genome_allele'),splt.index('mutated_from_allele'),splt.index('mutated_to_allele'))
					(d_id) = (splt.index('icgc_donor_id'))
				isFirst = False
				continue

			splt = line.strip().split('\t')
			# we have -1 because tsv/vcf coords are 1-based, and our reference string index is 0-based
			[chrName,chrStart,chrEnd] = [splt[c1],int(splt[c2])-1,int(splt[c3])-1]
			[allele_ref,allele_normal,allele_tumor] = [splt[m1].upper(),splt[m2].upper(),splt[m3].upper()]
			if IS_VCF:
				if len(allele_ref) != len(allele_tumor):
					# indels in tsv don't include the preserved first nucleotide, so lets trim the vcf alleles
					[allele_ref,allele_normal,allele_tumor] = [allele_ref[1:],allele_normal[1:],allele_tumor[1:]]
				if not allele_ref: allele_ref = '-'
				if not allele_normal: allele_normal = '-'
				if not allele_tumor: allele_tumor = '-'
				# if alternate alleles are present, lets just ignore this variant. I may come back and improve this later
				if ',' in allele_tumor:
					continue
				vcf_info = ';'+splt[7]+';'
			else:
				[donor_id] = [splt[d_id]]

			# if we encounter a multi-np (i.e. 3 nucl --> 3 different nucl), let's skip it for now...
			if ('-' not in allele_normal and '-' not in allele_tumor) and (len(allele_normal) > 1 or len(allele_tumor) > 1):
				print 'skipping a complex variant...'
				continue

			# to deal with '1' vs 'chr1' references, manually change names. this is hacky and bad.
			if 'chr' not in chrName:
				chrName = 'chr'+chrName
			if 'chr' not in refName:
				refName = 'chr'+refName
			# skip irrelevant variants
			if chrName != refName:
				continue

			# if variant is outside the regions we're interested in (if specified), skip it...
			if MYBED != None:
				refKey = refName
				if not refKey in MYBED[0] and refKey[3:] in MYBED[0]:	# account for 1 vs chr1, again...
					refKey = refKey[3:]
				if refKey not in MYBED[0]:
					inBed = False
				else:
					inBed = isInBed(MYBED[0][refKey],chrStart)
				if inBed != MYBED[1]:
					continue

			# we want only snps
			# so, no '-' characters allowed, and chrStart must be same as chrEnd
			if '-' not in allele_normal and '-' not in allele_tumor and chrStart == chrEnd:
				trinuc_ref = refSequence[chrStart-1:chrStart+2]
				if not trinuc_ref in VALID_TRINUC:
					continue	# skip ref trinuc with invalid characters
				# only consider positions where ref allele in tsv matches the nucleotide in our reference
				if allele_ref == trinuc_ref[1]:
					trinuc_normal    = refSequence[chrStart-1] + allele_normal + refSequence[chrStart+1]
					trinuc_tumor     = refSequence[chrStart-1] + allele_tumor + refSequence[chrStart+1]
					if not trinuc_normal in VALID_TRINUC or not trinuc_tumor in VALID_TRINUC:
						continue	# skip if mutation contains invalid char
					key = (trinuc_normal,trinuc_tumor)
					if key not in TRINUC_TRANSITION_COUNT:
						TRINUC_TRANSITION_COUNT[key] = 0
					TRINUC_TRANSITION_COUNT[key] += 1
					SNP_COUNT += 1
					key2 = (allele_normal,allele_tumor)
					if key2 not in SNP_TRANSITION_COUNT:
						SNP_TRANSITION_COUNT[key2] = 0
					SNP_TRANSITION_COUNT[key2] += 1

					if IS_VCF:
						myPopFreq = VCF_DEFAULT_POP_FREQ
						if ';CAF=' in vcf_info:
							cafStr = re.findall(r";CAF=.*?(?=;)",vcf_info)[0]
							if ',' in cafStr:
								myPopFreq = float(cafStr[5:].split(',')[1])
						VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor,myPopFreq))
					else:
						VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor))
						TOTAL_DONORS[donor_id] = True
				else:
					print '\nError: ref allele in variant call does not match reference.\n'
					exit(1)

			# now let's look for indels...
			if '-' in allele_normal: len_normal = 0
			else: len_normal = len(allele_normal)
			if '-' in allele_tumor: len_tumor = 0
			else: len_tumor = len(allele_tumor)
			if len_normal != len_tumor:
				indel_len = len_tumor - len_normal
				if indel_len not in INDEL_COUNT:
					INDEL_COUNT[indel_len] = 0
				INDEL_COUNT[indel_len] += 1

				if IS_VCF:
					myPopFreq = VCF_DEFAULT_POP_FREQ
					if ';CAF=' in vcf_info:
						cafStr = re.findall(r";CAF=.*?(?=;)",vcf_info)[0]
						if ',' in cafStr:
							myPopFreq = float(cafStr[5:].split(',')[1])
					VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor,myPopFreq))
				else:
					VDAT_COMMON.append((chrStart,allele_ref,allele_normal,allele_tumor))
					TOTAL_DONORS[donor_id] = True
		f.close()

		# if we didn't find anything, skip ahead along to the next reference sequence
		if not len(VDAT_COMMON):
			print 'Found no variants for this reference, moving along...'
			continue

		#
		# identify common mutations
		#
		percentile_var = 95
		if IS_VCF:
			minVal = np.percentile([n[4] for n in VDAT_COMMON],percentile_var)
			for k in sorted(VDAT_COMMON):
				if k[4] >= minVal:
					COMMON_VARIANTS.append((refName,k[0],k[1],k[3],k[4]))
			VDAT_COMMON = {(n[0],n[1],n[2],n[3]):n[4] for n in VDAT_COMMON}
		else:
			N_DONORS = len(TOTAL_DONORS)
			VDAT_COMMON = list_2_countDict(VDAT_COMMON)
			minVal = int(np.percentile(VDAT_COMMON.values(),percentile_var))
			for k in sorted(VDAT_COMMON.keys()):
				if VDAT_COMMON[k] >= minVal:
					COMMON_VARIANTS.append((refName,k[0],k[1],k[3],VDAT_COMMON[k]/float(N_DONORS)))

		#
		# identify areas that have contained significantly higher random mutation rates
		#
		dist_thresh      = 2000
		percentile_clust = 97
		qptn             = 1000
		# identify regions with disproportionately more variants in them
		VARIANT_POS = sorted([n[0] for n in VDAT_COMMON.keys()])
		clustered_pos = clusterList(VARIANT_POS,dist_thresh)
		byLen  = [(len(clustered_pos[i]),min(clustered_pos[i]),max(clustered_pos[i]),i) for i in xrange(len(clustered_pos))]
		#byLen  = sorted(byLen,reverse=True)
		#minLen = int(np.percentile([n[0] for n in byLen],percentile_clust))
		#byLen  = [n for n in byLen if n[0] >= minLen]
		candidate_regions = []
		for n in byLen:
			bi = int((n[1]-dist_thresh)/float(qptn))*qptn
			bf = int((n[2]+dist_thresh)/float(qptn))*qptn
			candidate_regions.append((n[0]/float(bf-bi),max([0,bi]),min([len(refSequence),bf])))
		minVal = np.percentile([n[0] for n in candidate_regions],percentile_clust)
		for n in candidate_regions:
			if n[0] >= minVal:
				HIGH_MUT_REGIONS.append((refName,n[1],n[2],n[0]))
		# collapse overlapping regions
		for i in xrange(len(HIGH_MUT_REGIONS)-1,0,-1):
			if HIGH_MUT_REGIONS[i-1][2] >= HIGH_MUT_REGIONS[i][1] and HIGH_MUT_REGIONS[i-1][0] == HIGH_MUT_REGIONS[i][0]:
				avgMutRate = 0.5*HIGH_MUT_REGIONS[i-1][3]+0.5*HIGH_MUT_REGIONS[i][3]	# not accurate, but I'm lazy
				HIGH_MUT_REGIONS[i-1] = (HIGH_MUT_REGIONS[i-1][0], HIGH_MUT_REGIONS[i-1][1], HIGH_MUT_REGIONS[i][2], avgMutRate)
				del HIGH_MUT_REGIONS[i]

	#
	# if we didn't count ref trinucs because we found file, read in ref counts from file now
	#
	if os.path.isfile(REF+'.trinucCounts'):
		print 'reading pre-computed trinuc counts...'
		f = open(REF+'.trinucCounts','r')
		for line in f:
			splt = line.strip().split('\t')
			TRINUC_REF_COUNT[splt[0]] = int(splt[1])
		f.close()
	# otherwise, save trinuc counts to file, if desired
	elif SAVE_TRINUC:
		if MYBED != None:
			print 'unable to save trinuc counts to file because using input bed region...'
		else:
			print 'saving trinuc counts to file...'
			f = open(REF+'.trinucCounts','w')
			for trinuc in sorted(TRINUC_REF_COUNT.keys()):
				f.write(trinuc+'\t'+str(TRINUC_REF_COUNT[trinuc])+'\n')
			f.close()

	#
	# if using an input bed region, make necessary adjustments to trinuc ref counts based on the bed region trinuc counts
	#
	if MYBED != None:
		if MYBED[1] == True:	# we are restricting our attention to bed regions, so ONLY use bed region trinuc counts
			TRINUC_REF_COUNT = TRINUC_BED_COUNT
		else:					# we are only looking outside bed regions, so subtract bed region trinucs from entire reference trinucs
			for k in TRINUC_REF_COUNT.keys():
				if k in TRINUC_BED_COUNT:
					TRINUC_REF_COUNT[k] -= TRINUC_BED_COUNT[k]


	""" ##########################################################################
	###							COMPUTE PROBABILITIES						   ###
	########################################################################## """


	#for k in sorted(TRINUC_REF_COUNT.keys()):
	#		print k, TRINUC_REF_COUNT[k]
	#
	#for k in sorted(TRINUC_TRANSITION_COUNT.keys()):
	#	print k, TRINUC_TRANSITION_COUNT[k]

	# frequency that each trinuc mutated into anything else
	TRINUC_MUT_PROB = {}
	# frequency that a trinuc mutates into another trinuc, given that it mutated
	TRINUC_TRANS_PROBS = {}
	# frequency of snp transitions, given a snp occurs.
	SNP_TRANS_FREQ = {}

	for trinuc in sorted(TRINUC_REF_COUNT.keys()):
		myCount = 0
		for k in sorted(TRINUC_TRANSITION_COUNT.keys()):
			if k[0] == trinuc:
				myCount += TRINUC_TRANSITION_COUNT[k]
		TRINUC_MUT_PROB[trinuc] = myCount / float(TRINUC_REF_COUNT[trinuc])
		for k in sorted(TRINUC_TRANSITION_COUNT.keys()):
			if k[0] == trinuc:
				TRINUC_TRANS_PROBS[k] = TRINUC_TRANSITION_COUNT[k] / float(myCount)

	for n1 in VALID_NUCL:
		rollingTot = sum([SNP_TRANSITION_COUNT[(n1,n2)] for n2 in VALID_NUCL if (n1,n2) in SNP_TRANSITION_COUNT])
		for n2 in VALID_NUCL:
			key2 = (n1,n2)
			if key2 in SNP_TRANSITION_COUNT:
				SNP_TRANS_FREQ[key2] = SNP_TRANSITION_COUNT[key2] / float(rollingTot)

	# compute average snp and indel frequencies
	totalVar       = SNP_COUNT + sum(INDEL_COUNT.values())
	SNP_FREQ       = SNP_COUNT/float(totalVar)
	AVG_INDEL_FREQ = 1.-SNP_FREQ
	INDEL_FREQ     = {k:(INDEL_COUNT[k]/float(totalVar))/AVG_INDEL_FREQ for k in INDEL_COUNT.keys()}
	if MYBED != None:
		if MYBED[1] == True:
			AVG_MUT_RATE = totalVar/float(getTrackLen(MYBED[0]))
		else:
			AVG_MUT_RATE = totalVar/float(TOTAL_REFLEN - getTrackLen(MYBED[0]))
	else:
		AVG_MUT_RATE = totalVar/float(TOTAL_REFLEN)

	#
	#	print some stuff
	#
	for k in sorted(TRINUC_MUT_PROB.keys()):
		print 'p('+k+' mutates) =',TRINUC_MUT_PROB[k]

	for k in sorted(TRINUC_TRANS_PROBS.keys()):
		print 'p('+k[0]+' --> '+k[1]+' | '+k[0]+' mutates) =',TRINUC_TRANS_PROBS[k]

	for k in sorted(INDEL_FREQ.keys()):
		if k > 0:
			print 'p(ins length = '+str(abs(k))+' | indel occurs) =',INDEL_FREQ[k]
		else:
			print 'p(del length = '+str(abs(k))+' | indel occurs) =',INDEL_FREQ[k]

	for k in sorted(SNP_TRANS_FREQ.keys()):
		print 'p('+k[0]+' --> '+k[1]+' | SNP occurs) =',SNP_TRANS_FREQ[k]

	#for n in COMMON_VARIANTS:
	#	print n

	#for n in HIGH_MUT_REGIONS:
	#	print n

	print 'p(snp)   =',SNP_FREQ
	print 'p(indel) =',AVG_INDEL_FREQ
	print 'overall average mut rate:',AVG_MUT_RATE
	print 'total variants processed:',totalVar

	#
	# save variables to file
	#
	OUT_DICT = {'AVG_MUT_RATE':AVG_MUT_RATE,
	            'SNP_FREQ':SNP_FREQ,
	            'SNP_TRANS_FREQ':SNP_TRANS_FREQ,
	            'INDEL_FREQ':INDEL_FREQ,
	            'TRINUC_MUT_PROB':TRINUC_MUT_PROB,
	            'TRINUC_TRANS_PROBS':TRINUC_TRANS_PROBS,
	            'COMMON_VARIANTS':COMMON_VARIANTS,
	            'HIGH_MUT_REGIONS':HIGH_MUT_REGIONS}
	pickle.dump( OUT_DICT, open( OUT_PICKLE, "wb" ) )
Example #4
0
def main():

	# index reference
	refIndex = indexRef(REFERENCE)
	if PAIRED_END:
		N_HANDLING = ('random',FRAGMENT_SIZE)
	else:
		N_HANDLING = ('ignore',READLEN)
	indices_by_refName = {refIndex[n][0]:n for n in xrange(len(refIndex))}

	# parse input variants, if present
	inputVariants = []
	if INPUT_VCF != None:
		if CANCER:
			(sampNames, inputVariants) = parseVCF(INPUT_VCF,tumorNormal=True,ploidy=PLOIDS)
			tumorInd  = sampNames.index('TUMOR')
			normalInd = sampNames.index('NORMAL')
		else:
			(sampNames, inputVariants) = parseVCF(INPUT_VCF,ploidy=PLOIDS)
		for k in sorted(inputVariants.keys()):
			inputVariants[k].sort()

	# parse input targeted regions, if present
	inputRegions = {}
	if INPUT_BED != None:
		with open(INPUT_BED,'r') as f:
			for line in f:
				[myChr,pos1,pos2] = line.strip().split('\t')[:3]
				if myChr not in inputRegions:
					inputRegions[myChr] = [-1]
				inputRegions[myChr].extend([int(pos1),int(pos2)])

	# parse input mutation rate rescaling regions, if present
	mutRateRegions = {}
	mutRateValues  = {}
	if MUT_BED != None:
		with open(MUT_BED,'r') as f:
			for line in f:
				[myChr,pos1,pos2,metaData] = line.strip().split('\t')[:4]
				mutStr = re.findall(r"MUT_RATE=.*?(?=;)",metaData+';')
				(pos1,pos2) = (int(pos1),int(pos2))
				if len(mutStr) and (pos2-pos1) > 1:
					# mutRate = #_mutations / length_of_region, let's bound it by a reasonable amount
					mutRate = max([0.0,min([float(mutStr[0][9:]),0.3])])
					if myChr not in inputRegions:
						mutRateRegions[myChr] = [-1]
						mutRateValues[myChr]  = [0.0]
					mutRateRegions[myChr].extend([pos1,pos2])
					mutRateValues.extend([mutRate*(pos2-pos1)]*2)

	# initialize output files (part I)
	bamHeader = None
	if SAVE_BAM:
		bamHeader = [copy.deepcopy(refIndex)]
	vcfHeader = None
	if SAVE_VCF:
		vcfHeader = [REFERENCE]
	
	# If processing jobs in parallel, precompute the independent regions that can be process separately
	if NJOBS > 1:
		parallelRegionList  = getAllRefRegions(REFERENCE,refIndex,N_HANDLING,saveOutput=SAVE_NON_N)
		(myRefs, myRegions) = partitionRefRegions(parallelRegionList,refIndex,MYJOB,NJOBS)
		if not len(myRegions):
			print 'This job id has no regions to process, exiting...'
			exit(1)
		for i in xrange(len(refIndex)-1,-1,-1):	# delete reference not used in our job
			if not refIndex[i][0] in myRefs:
				del refIndex[i]
		# if value of NJOBS is too high, let's change it to the maximum possible, to avoid output filename confusion
		corrected_nJobs = min([NJOBS,sum([len(n) for n in parallelRegionList.values()])])
	else:
		corrected_nJobs = 1

	# initialize output files (part II)
	if CANCER:
		OFW = OutputFileWriter(OUT_PREFIX+'_normal',paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT)
		OFW_CANCER = OutputFileWriter(OUT_PREFIX+'_tumor',paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,jobTuple=(MYJOB,corrected_nJobs))
	else:
		OFW = OutputFileWriter(OUT_PREFIX,paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,jobTuple=(MYJOB,corrected_nJobs))
	OUT_PREFIX_NAME = OUT_PREFIX.split('/')[-1]


	"""************************************************
	****        LET'S GET THIS PARTY STARTED...
	************************************************"""


	readNameCount = 1	# keep track of the number of reads we've sampled, for read-names

	for RI in xrange(len(refIndex)):

		# read in reference sequence and notate blocks of Ns
		(refSequence,N_regions) = readRef(REFERENCE,refIndex[RI],N_HANDLING)

		# if we're processing jobs in parallel only take the regions relevant for the current job
		if NJOBS > 1:
			for i in xrange(len(N_regions['non_N'])-1,-1,-1):
				if not (refIndex[RI][0],N_regions['non_N'][i][0],N_regions['non_N'][i][1]) in myRegions:
					del N_regions['non_N'][i]

		# count total bp we'll be spanning so we can get an idea of how far along we are (for printing progress indicators)
		total_bp_span   = sum([n[1]-n[0] for n in N_regions['non_N']])
		currentProgress = 0
		currentPercent  = 0

		# prune invalid input variants, e.g variants that:
		#		- try to delete or alter any N characters
		#		- don't match the reference base at their specified position
		#		- any alt allele contains anything other than allowed characters
		validVariants = []
		nSkipped = [0,0,0]
		if refIndex[RI][0] in inputVariants:
			for n in inputVariants[refIndex[RI][0]]:
				span = (n[0],n[0]+len(n[1]))
				rseq = str(refSequence[span[0]-1:span[1]-1])	# -1 because going from VCF coords to array coords
				anyBadChr = any((nn not in ALLOWED_NUCL) for nn in [item for sublist in n[2] for item in sublist])
				if rseq != n[1]:
					nSkipped[0] += 1
					continue
				elif 'N' in rseq:
					nSkipped[1] += 1
					continue
				elif anyBadChr:
					nSkipped[2] += 1
					continue
				#if bisect.bisect(N_regions['big'],span[0])%2 or bisect.bisect(N_regions['big'],span[1])%2:
				#	continue
				validVariants.append(n)
			print 'found',len(validVariants),'valid variants for '+refIndex[RI][0]+' in input VCF...'
			if any(nSkipped):
				print sum(nSkipped),'variants skipped...'
				print ' - ['+str(nSkipped[0])+'] ref allele does not match reference'
				print ' - ['+str(nSkipped[1])+'] attempting to insert into N-region'
				print ' - ['+str(nSkipped[2])+'] alt allele contains non-ACGT characters'

		# add large random structural variants
		#
		#	TBD!!!

		# determine which structural variants will affect our sampling window positions
		structuralVars = []
		for n in validVariants:
			bufferNeeded = max([max([len(n[1])-len(alt_allele),1]) for alt_allele in n[2]])
			structuralVars.append((n[0]-1,bufferNeeded))	# -1 because going from VCF coords to array coords

		# determine sampling windows based on read length, large N regions, and structural mutations.
		# in order to obtain uniform coverage, windows should overlap by:
		#		- READLEN, if single-end reads
		#		- FRAGMENT_SIZE (mean), if paired-end reads
		# ploidy is fixed per large sampling window,
		# coverage distributions due to GC% and targeted regions are specified within these windows
		samplingWindows  = []
		ALL_VARIANTS_OUT = {}
		sequences        = None
		if PAIRED_END:
			targSize = WINDOW_TARGET_SCALE*FRAGMENT_SIZE
			overlap  = FRAGMENT_SIZE
		else:
			targSize = WINDOW_TARGET_SCALE*READLEN
			overlap  = READLEN

		print '--------------------------------'
		print 'sampling reads...'
		for i in xrange(len(N_regions['non_N'])):
			(pi,pf) = N_regions['non_N'][i]
			nTargWindows = max([1,(pf-pi)/targSize])
			bpd = int((pf-pi)/float(nTargWindows))
			bpd += GC_WINDOW_SIZE - bpd%GC_WINDOW_SIZE

			#print len(refSequence), (pi,pf), nTargWindows
			#print structuralVars

			# if for some reason our region is too small to process, skip it! (sorry)
			if nTargWindows == 1 and (pf-pi) < overlap-1:
				#print 'Does this ever happen?'
				continue

			start = pi
			end   = min([start+bpd,pf])
			#print '------------------RAWR:', (pi,pf), bpd
			currentVariantInd = 0
			varsFromPrevOverlap = []
			varsCancerFromPrevOverlap = []
			vindFromPrev = 0
			isLastTime = False
			while True:
				####print (start,end)
				# adjust end-position of window based on inserted structural mutations
				relevantVars = []
				if len(structuralVars) and currentVariantInd < len(structuralVars):
					prevVarInd = currentVariantInd
					while structuralVars[currentVariantInd][0] <= end:
						delta = (end-1) - (structuralVars[currentVariantInd][0] + structuralVars[currentVariantInd][1])
						if delta <= 0:
							####print 'DELTA:', delta
							end -= (delta-1)
						currentVariantInd += 1
						if currentVariantInd == len(structuralVars):
							break
					relevantVars = structuralVars[prevVarInd:currentVariantInd]
				next_start = end-overlap
				next_end   = min([next_start+bpd,pf])
				if next_end-next_start < bpd:
					end = next_end
					isLastTime = True

				# print progress indicator
				#print 'PROCESSING WINDOW:',(start,end)
				currentProgress += end-start
				newPercent = int((currentProgress*100)/float(total_bp_span))
				if newPercent > currentPercent:
					sys.stdout.write(str(newPercent)+'% ')
					sys.stdout.flush()
					currentPercent = newPercent

				# which inserted variants are in this window?
				varsInWindow = []
				updated = False
				for j in xrange(vindFromPrev,len(validVariants)):
					vPos = validVariants[j][0]
					if vPos >= start and vPos < end:
						varsInWindow.append(tuple([vPos-1]+list(validVariants[j][1:])))	# vcf --> array coords
					if vPos >= end-overlap-1 and updated == False:
						updated = True
						vindFromPrev = j
					if vPos >= end:
						break

				# if computing only VCF, we can skip this...
				if ONLY_VCF:
					coverage_dat = None
					coverage_avg = None
				else:
					# pre-compute gc-bias and targeted sequencing coverage modifiers
					nSubWindows  = (end-start)/GC_WINDOW_SIZE
					coverage_dat = (GC_WINDOW_SIZE,[])
					for j in xrange(nSubWindows):
						rInd = start + j*GC_WINDOW_SIZE
						if INPUT_BED == None: tCov = True
						else: tCov = not(bisect.bisect(inputRegions[myChr],rInd)%2) or not(bisect.bisect(inputRegions[myChr],rInd+GC_WINDOW_SIZE)%2)
						if tCov: tScl = 1.0
						else: tScl = OFFTARGET_SCALAR
						gc_v = refSequence[rInd:rInd+GC_WINDOW_SIZE].count('G') + refSequence[rInd:rInd+GC_WINDOW_SIZE].count('C')
						gScl = GC_SCALE_VAL[gc_v]
						coverage_dat[1].append(1.0*tScl*gScl)
					coverage_avg = np.mean(coverage_dat[1])

				# pre-compute mutation rate tracks
				# PROVIDED MUTATION RATES OVERRIDE AVERAGE VALUE

				# construct sequence data that we will sample reads from
				if sequences == None:
					sequences = SequenceContainer(start,refSequence[start:end],PLOIDS,overlap,READLEN,[MUT_MODEL]*PLOIDS,MUT_RATE,coverage_dat,onlyVCF=ONLY_VCF)
				else:
					sequences.update(start,refSequence[start:end],PLOIDS,overlap,READLEN,[MUT_MODEL]*PLOIDS,MUT_RATE,coverage_dat)

				# adjust position of all inserted variants to match current window offset
				#variants_to_insert = []
				#for n in varsFromPrevOverlap:
				#	ln = [n[0]-start] + list(n[1:])
				#	variants_to_insert.append(tuple(ln))
				#for n in varsInWindow:
				#	ln = [n[0]-start] + list(n[1:])
				#	variants_to_insert.append(tuple(ln))
				#sequences.insert_mutations(variants_to_insert)
				sequences.insert_mutations(varsFromPrevOverlap + varsInWindow)
				all_inserted_variants = sequences.random_mutations()
				#print all_inserted_variants

				if CANCER:
					tumor_sequences = SequenceContainer(start,refSequence[start:end],PLOIDS,overlap,READLEN,[CANCER_MODEL]*PLOIDS,MUT_RATE,coverage_dat)
					tumor_sequences.insert_mutations(varsCancerFromPrevOverlap + all_inserted_variants)
					all_cancer_variants = tumor_sequences.random_mutations()

				# which variants do we need to keep for next time (because of window overlap)?
				varsFromPrevOverlap       = []
				varsCancerFromPrevOverlap = []
				for n in all_inserted_variants:
					if n[0] >= end-overlap-1:
						varsFromPrevOverlap.append(n)
				if CANCER:
					for n in all_cancer_variants:
						if n[0] >= end-overlap-1:
							varsCancerFromPrevOverlap.append(n)

				# if we're only producing VCF, no need to go through the hassle of generating reads
				if ONLY_VCF:
					pass
				else:
					# for each sampling window, construct sub-windows with coverage information
					covWindows = [COVERAGE for n in xrange((end-start)/SMALL_WINDOW)]
					if (end-start)%SMALL_WINDOW:
						covWindows.append(COVERAGE)
					meanCov = sum(covWindows)/float(len(covWindows))
					if PAIRED_END:
						readsToSample = int(((end-start)*meanCov*coverage_avg)/(2*READLEN))+1
					else:
						readsToSample = int(((end-start)*meanCov*coverage_avg)/(READLEN))+1

					# sample reads from altered reference
					for i in xrange(readsToSample):

						if PAIRED_END:
							myFraglen = FRAGLEN_DISTRIBUTION.sample()
							myReadData = sequences.sample_read(SE_CLASS,myFraglen)
							myReadData[0][0] += start	# adjust mapping position based on window start
							myReadData[1][0] += start
						else:
							myReadData = sequences.sample_read(SE_CLASS)
							myReadData[0][0] += start	# adjust mapping position based on window start
					
						if NJOBS > 1:
							myReadName = OUT_PREFIX_NAME+'-j'+str(MYJOB)+'-'+refIndex[RI][0]+'-r'+str(readNameCount)
						else:
							myReadName = OUT_PREFIX_NAME+'-'+refIndex[RI][0]+'-'+str(readNameCount)
						readNameCount += len(myReadData)

						# if desired, replace all low-quality bases with Ns
						if N_MAX_QUAL > -1:
							for j in xrange(len(myReadData)):
								myReadString = [n for n in myReadData[j][2]]
								for k in xrange(len(myReadData[j][3])):
									adjusted_qual = ord(myReadData[j][3][k])-SE_CLASS.offQ
									if adjusted_qual <= N_MAX_QUAL:
										myReadString[k] = 'N'
								myReadData[j][2] = ''.join(myReadString)

						# write read data out to FASTQ and BAM files, bypass FASTQ if option specified
						myRefIndex = indices_by_refName[refIndex[RI][0]]
						if len(myReadData) == 1:
							if NO_FASTQ != True:
								OFW.writeFASTQRecord(myReadName,myReadData[0][2],myReadData[0][3])
							if SAVE_BAM:
								OFW.writeBAMRecord(myRefIndex, myReadName+'/1', myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=0)
						elif len(myReadData) == 2:
							if NO_FASTQ != True:
								OFW.writeFASTQRecord(myReadName,myReadData[0][2],myReadData[0][3],read2=myReadData[1][2],qual2=myReadData[1][3])
							if SAVE_BAM:
								OFW.writeBAMRecord(myRefIndex, myReadName+'/1', myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=99,  matePos=myReadData[1][0])
								OFW.writeBAMRecord(myRefIndex, myReadName+'/2', myReadData[1][0], myReadData[1][1], myReadData[1][2], myReadData[1][3], samFlag=147, matePos=myReadData[0][0])
						else:
							print '\nError: Unexpected number of reads generated...\n'
							exit(1)

				# tally up all the variants that got successfully introduced
				for n in all_inserted_variants:
					ALL_VARIANTS_OUT[n] = True

				# prepare indices of next window
				start = next_start
				end   = next_end
				if isLastTime:
					break
				if end >= pf:
					isLastTime = True

		if currentPercent != 100:
			print '100%'
		else:
			print ''

		# write all output variants for this reference
		if SAVE_VCF:
			for k in sorted(ALL_VARIANTS_OUT.keys()):
				currentRef = refIndex[RI][0]
				myID       = '.'
				myQual     = '.'
				myFilt     = 'PASS'
				# k[0] + 1 because we're going back to 1-based vcf coords
				OFW.writeVCFRecord(currentRef, str(int(k[0])+1), myID, k[1], k[2], myQual, myFilt, k[4])

		#break

	# close output files
	OFW.closeFiles()
	if CANCER:
		OFW_CANCER.closeFiles()
Example #5
0
def main():

	# index reference
	refIndex = indexRef(REFERENCE)
	if PAIRED_END:
		N_HANDLING = ('random',FRAGMENT_SIZE)
	else:
		N_HANDLING = ('ignore',READLEN)
	indices_by_refName = {refIndex[n][0]:n for n in xrange(len(refIndex))}

	# parse input variants, if present
	inputVariants = []
	if INPUT_VCF != None:
		if CANCER:
			(sampNames, inputVariants) = parseVCF(INPUT_VCF,tumorNormal=True,ploidy=PLOIDS)
			tumorInd  = sampNames.index('TUMOR')
			normalInd = sampNames.index('NORMAL')
		else:
			(sampNames, inputVariants) = parseVCF(INPUT_VCF,ploidy=PLOIDS)
		for k in sorted(inputVariants.keys()):
			inputVariants[k].sort()

	# parse input targeted regions, if present
	inputRegions = {}
	if INPUT_BED != None:
		with open(INPUT_BED,'r') as f:
			for line in f:
				[myChr,pos1,pos2] = line.strip().split('\t')[:3]
				if myChr not in inputRegions:
					inputRegions[myChr] = [-1]
				inputRegions[myChr].extend([int(pos1),int(pos2)])

	# parse input mutation rate rescaling regions, if present
	mutRateRegions = {}
	mutRateValues  = {}
	if MUT_BED != None:
		with open(MUT_BED,'r') as f:
			for line in f:
				[myChr,pos1,pos2,metaData] = line.strip().split('\t')[:4]
				mutStr = re.findall(r"MUT_RATE=.*?(?=;)",metaData+';')
				(pos1,pos2) = (int(pos1),int(pos2))
				if len(mutStr) and (pos2-pos1) > 1:
					# mutRate = #_mutations / length_of_region, let's bound it by a reasonable amount
					mutRate = max([0.0,min([float(mutStr[0][9:]),0.3])])
					if myChr not in mutRateRegions:
						mutRateRegions[myChr] = [-1]
						mutRateValues[myChr]  = [0.0]
					mutRateRegions[myChr].extend([pos1,pos2])
					mutRateValues.extend([mutRate*(pos2-pos1)]*2)

	# initialize output files (part I)
	bamHeader = None
	if SAVE_BAM:
		bamHeader = [copy.deepcopy(refIndex)]
	vcfHeader = None
	if SAVE_VCF:
		vcfHeader = [REFERENCE]
	
	# If processing jobs in parallel, precompute the independent regions that can be process separately
	if NJOBS > 1:
		parallelRegionList  = getAllRefRegions(REFERENCE,refIndex,N_HANDLING,saveOutput=SAVE_NON_N)
		(myRefs, myRegions) = partitionRefRegions(parallelRegionList,refIndex,MYJOB,NJOBS)
		if not len(myRegions):
			print 'This job id has no regions to process, exiting...'
			exit(1)
		for i in xrange(len(refIndex)-1,-1,-1):	# delete reference not used in our job
			if not refIndex[i][0] in myRefs:
				del refIndex[i]
		# if value of NJOBS is too high, let's change it to the maximum possible, to avoid output filename confusion
		corrected_nJobs = min([NJOBS,sum([len(n) for n in parallelRegionList.values()])])
	else:
		corrected_nJobs = 1

	# initialize output files (part II)
	if CANCER:
		OFW = OutputFileWriter(OUT_PREFIX+'_normal',paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT)
		OFW_CANCER = OutputFileWriter(OUT_PREFIX+'_tumor',paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,jobTuple=(MYJOB,corrected_nJobs))
	else:
		OFW = OutputFileWriter(OUT_PREFIX,paired=PAIRED_END,BAM_header=bamHeader,VCF_header=vcfHeader,gzipped=GZIPPED_OUT,jobTuple=(MYJOB,corrected_nJobs))
	OUT_PREFIX_NAME = OUT_PREFIX.split('/')[-1]


	"""************************************************
	****        LET'S GET THIS PARTY STARTED...
	************************************************"""


	readNameCount = 1	# keep track of the number of reads we've sampled, for read-names

	for RI in xrange(len(refIndex)):

		# read in reference sequence and notate blocks of Ns
		(refSequence,N_regions) = readRef(REFERENCE,refIndex[RI],N_HANDLING)

		# if we're processing jobs in parallel only take the regions relevant for the current job
		if NJOBS > 1:
			for i in xrange(len(N_regions['non_N'])-1,-1,-1):
				if not (refIndex[RI][0],N_regions['non_N'][i][0],N_regions['non_N'][i][1]) in myRegions:
					del N_regions['non_N'][i]

		# count total bp we'll be spanning so we can get an idea of how far along we are (for printing progress indicators)
		total_bp_span   = sum([n[1]-n[0] for n in N_regions['non_N']])
		currentProgress = 0
		currentPercent  = 0

		# prune invalid input variants, e.g variants that:
		#		- try to delete or alter any N characters
		#		- don't match the reference base at their specified position
		#		- any alt allele contains anything other than allowed characters
		validVariants = []
		nSkipped = [0,0,0]
		if refIndex[RI][0] in inputVariants:
			for n in inputVariants[refIndex[RI][0]]:
				span = (n[0],n[0]+len(n[1]))
				rseq = str(refSequence[span[0]-1:span[1]-1])	# -1 because going from VCF coords to array coords
				anyBadChr = any((nn not in ALLOWED_NUCL) for nn in [item for sublist in n[2] for item in sublist])
				if rseq != n[1]:
					nSkipped[0] += 1
					continue
				elif 'N' in rseq:
					nSkipped[1] += 1
					continue
				elif anyBadChr:
					nSkipped[2] += 1
					continue
				#if bisect.bisect(N_regions['big'],span[0])%2 or bisect.bisect(N_regions['big'],span[1])%2:
				#	continue
				validVariants.append(n)
			print 'found',len(validVariants),'valid variants for '+refIndex[RI][0]+' in input VCF...'
			if any(nSkipped):
				print sum(nSkipped),'variants skipped...'
				print ' - ['+str(nSkipped[0])+'] ref allele does not match reference'
				print ' - ['+str(nSkipped[1])+'] attempting to insert into N-region'
				print ' - ['+str(nSkipped[2])+'] alt allele contains non-ACGT characters'

		# add large random structural variants
		#
		#	TBD!!!

		# determine which structural variants will affect our sampling window positions
		structuralVars = []
		for n in validVariants:
			bufferNeeded = max([max([len(n[1])-len(alt_allele),1]) for alt_allele in n[2]])
			structuralVars.append((n[0]-1,bufferNeeded))	# -1 because going from VCF coords to array coords

		# determine sampling windows based on read length, large N regions, and structural mutations.
		# in order to obtain uniform coverage, windows should overlap by:
		#		- READLEN, if single-end reads
		#		- FRAGMENT_SIZE (mean), if paired-end reads
		# ploidy is fixed per large sampling window,
		# coverage distributions due to GC% and targeted regions are specified within these windows
		samplingWindows  = []
		ALL_VARIANTS_OUT = {}
		sequences        = None
		if PAIRED_END:
			targSize = WINDOW_TARGET_SCALE*FRAGMENT_SIZE
			overlap  = FRAGMENT_SIZE
		else:
			targSize = WINDOW_TARGET_SCALE*READLEN
			overlap  = READLEN

		print '--------------------------------'
		print 'sampling reads...'
		for i in xrange(len(N_regions['non_N'])):
			(pi,pf) = N_regions['non_N'][i]
			nTargWindows = max([1,(pf-pi)/targSize])
			bpd = int((pf-pi)/float(nTargWindows))
			bpd += GC_WINDOW_SIZE - bpd%GC_WINDOW_SIZE

			#print len(refSequence), (pi,pf), nTargWindows
			#print structuralVars

			# if for some reason our region is too small to process, skip it! (sorry)
			if nTargWindows == 1 and (pf-pi) < overlap-1:
				#print 'Does this ever happen?'
				continue

			start = pi
			end   = min([start+bpd,pf])
			#print '------------------RAWR:', (pi,pf), bpd
			currentVariantInd = 0
			varsFromPrevOverlap = []
			varsCancerFromPrevOverlap = []
			vindFromPrev = 0
			isLastTime = False
			while True:
				####print (start,end)
				# adjust end-position of window based on inserted structural mutations
				relevantVars = []
				if len(structuralVars) and currentVariantInd < len(structuralVars):
					prevVarInd = currentVariantInd
					while structuralVars[currentVariantInd][0] <= end:
						delta = (end-1) - (structuralVars[currentVariantInd][0] + structuralVars[currentVariantInd][1])
						if delta <= 0:
							####print 'DELTA:', delta
							end -= (delta-1)
						currentVariantInd += 1
						if currentVariantInd == len(structuralVars):
							break
					relevantVars = structuralVars[prevVarInd:currentVariantInd]
				next_start = end-overlap
				next_end   = min([next_start+bpd,pf])
				if next_end-next_start < bpd:
					end = next_end
					isLastTime = True

				# print progress indicator
				#print 'PROCESSING WINDOW:',(start,end)
				currentProgress += end-start
				newPercent = int((currentProgress*100)/float(total_bp_span))
				if newPercent > currentPercent:
					sys.stdout.write(str(newPercent)+'% ')
					sys.stdout.flush()
					currentPercent = newPercent

				# which inserted variants are in this window?
				varsInWindow = []
				updated = False
				for j in xrange(vindFromPrev,len(validVariants)):
					vPos = validVariants[j][0]
					if vPos >= start and vPos < end:
						varsInWindow.append(tuple([vPos-1]+list(validVariants[j][1:])))	# vcf --> array coords
					if vPos >= end-overlap-1 and updated == False:
						updated = True
						vindFromPrev = j
					if vPos >= end:
						break

				# if computing only VCF, we can skip this...
				if ONLY_VCF:
					coverage_dat = None
					coverage_avg = None
				else:
					# pre-compute gc-bias and targeted sequencing coverage modifiers
					nSubWindows  = (end-start)/GC_WINDOW_SIZE
					coverage_dat = (GC_WINDOW_SIZE,[])
					for j in xrange(nSubWindows):
						rInd = start + j*GC_WINDOW_SIZE
						if INPUT_BED == None: tCov = True
						else: tCov = not(bisect.bisect(inputRegions[myChr],rInd)%2) or not(bisect.bisect(inputRegions[myChr],rInd+GC_WINDOW_SIZE)%2)
						if tCov: tScl = 1.0
						else: tScl = OFFTARGET_SCALAR
						gc_v = refSequence[rInd:rInd+GC_WINDOW_SIZE].count('G') + refSequence[rInd:rInd+GC_WINDOW_SIZE].count('C')
						gScl = GC_SCALE_VAL[gc_v]
						coverage_dat[1].append(1.0*tScl*gScl)
					coverage_avg = np.mean(coverage_dat[1])

				# pre-compute mutation rate tracks
				# PROVIDED MUTATION RATES OVERRIDE AVERAGE VALUE

				# construct sequence data that we will sample reads from
				if sequences == None:
					sequences = SequenceContainer(start,refSequence[start:end],PLOIDS,overlap,READLEN,[MUT_MODEL]*PLOIDS,MUT_RATE,coverage_dat,onlyVCF=ONLY_VCF)
				else:
					sequences.update(start,refSequence[start:end],PLOIDS,overlap,READLEN,[MUT_MODEL]*PLOIDS,MUT_RATE,coverage_dat)

				# adjust position of all inserted variants to match current window offset
				#variants_to_insert = []
				#for n in varsFromPrevOverlap:
				#	ln = [n[0]-start] + list(n[1:])
				#	variants_to_insert.append(tuple(ln))
				#for n in varsInWindow:
				#	ln = [n[0]-start] + list(n[1:])
				#	variants_to_insert.append(tuple(ln))
				#sequences.insert_mutations(variants_to_insert)
				sequences.insert_mutations(varsFromPrevOverlap + varsInWindow)
				all_inserted_variants = sequences.random_mutations()
				#print all_inserted_variants

				if CANCER:
					tumor_sequences = SequenceContainer(start,refSequence[start:end],PLOIDS,overlap,READLEN,[CANCER_MODEL]*PLOIDS,MUT_RATE,coverage_dat)
					tumor_sequences.insert_mutations(varsCancerFromPrevOverlap + all_inserted_variants)
					all_cancer_variants = tumor_sequences.random_mutations()

				# which variants do we need to keep for next time (because of window overlap)?
				varsFromPrevOverlap       = []
				varsCancerFromPrevOverlap = []
				for n in all_inserted_variants:
					if n[0] >= end-overlap-1:
						varsFromPrevOverlap.append(n)
				if CANCER:
					for n in all_cancer_variants:
						if n[0] >= end-overlap-1:
							varsCancerFromPrevOverlap.append(n)

				# if we're only producing VCF, no need to go through the hassle of generating reads
				if ONLY_VCF:
					pass
				else:
					# for each sampling window, construct sub-windows with coverage information
					covWindows = [COVERAGE for n in xrange((end-start)/SMALL_WINDOW)]
					if (end-start)%SMALL_WINDOW:
						covWindows.append(COVERAGE)
					meanCov = sum(covWindows)/float(len(covWindows))
					if PAIRED_END:
						readsToSample = int(((end-start)*meanCov*coverage_avg)/(2*READLEN))+1
					else:
						readsToSample = int(((end-start)*meanCov*coverage_avg)/(READLEN))+1

					# sample reads from altered reference
					for i in xrange(readsToSample):

						if PAIRED_END:
							myFraglen = FRAGLEN_DISTRIBUTION.sample()
							myReadData = sequences.sample_read(SE_CLASS,myFraglen)
							myReadData[0][0] += start	# adjust mapping position based on window start
							myReadData[1][0] += start
						else:
							myReadData = sequences.sample_read(SE_CLASS)
							myReadData[0][0] += start	# adjust mapping position based on window start
					
						if NJOBS > 1:
							myReadName = OUT_PREFIX_NAME+'-j'+str(MYJOB)+'-'+refIndex[RI][0]+'-r'+str(readNameCount)
						else:
							myReadName = OUT_PREFIX_NAME+'-'+refIndex[RI][0]+'-'+str(readNameCount)
						readNameCount += len(myReadData)

						# if desired, replace all low-quality bases with Ns
						if N_MAX_QUAL > -1:
							for j in xrange(len(myReadData)):
								myReadString = [n for n in myReadData[j][2]]
								for k in xrange(len(myReadData[j][3])):
									adjusted_qual = ord(myReadData[j][3][k])-SE_CLASS.offQ
									if adjusted_qual <= N_MAX_QUAL:
										myReadString[k] = 'N'
								myReadData[j][2] = ''.join(myReadString)

						# write read data out to FASTQ and BAM files, bypass FASTQ if option specified
						myRefIndex = indices_by_refName[refIndex[RI][0]]
						if len(myReadData) == 1:
							if NO_FASTQ != True:
								OFW.writeFASTQRecord(myReadName,myReadData[0][2],myReadData[0][3])
							if SAVE_BAM:
								OFW.writeBAMRecord(myRefIndex, myReadName+'/1', myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=0)
						elif len(myReadData) == 2:
							if NO_FASTQ != True:
								OFW.writeFASTQRecord(myReadName,myReadData[0][2],myReadData[0][3],read2=myReadData[1][2],qual2=myReadData[1][3])
							if SAVE_BAM:
								OFW.writeBAMRecord(myRefIndex, myReadName+'/1', myReadData[0][0], myReadData[0][1], myReadData[0][2], myReadData[0][3], samFlag=99,  matePos=myReadData[1][0])
								OFW.writeBAMRecord(myRefIndex, myReadName+'/2', myReadData[1][0], myReadData[1][1], myReadData[1][2], myReadData[1][3], samFlag=147, matePos=myReadData[0][0])
						else:
							print '\nError: Unexpected number of reads generated...\n'
							exit(1)

				# tally up all the variants that got successfully introduced
				for n in all_inserted_variants:
					ALL_VARIANTS_OUT[n] = True

				# prepare indices of next window
				start = next_start
				end   = next_end
				if isLastTime:
					break
				if end >= pf:
					isLastTime = True

		if currentPercent != 100:
			print '100%'
		else:
			print ''

		# write all output variants for this reference
		if SAVE_VCF:
			for k in sorted(ALL_VARIANTS_OUT.keys()):
				currentRef = refIndex[RI][0]
				myID       = '.'
				myQual     = '.'
				myFilt     = 'PASS'
				# k[0] + 1 because we're going back to 1-based vcf coords
				OFW.writeVCFRecord(currentRef, str(int(k[0])+1), myID, k[1], k[2], myQual, myFilt, k[4])

		#break

	# close output files
	OFW.closeFiles()
	if CANCER:
		OFW_CANCER.closeFiles()