def _intervals_from_args(args):
    """Return the regions to scan as ``{contig: (starts, ends)}``.

    ``starts`` and ``ends`` are parallel sequences that ``infer_vcf`` zips
    together.  A single ``--chrom/--start_pos/--end_pos`` region is wrapped in
    one-element lists so it has the same layout as the BED-file path (the
    previous ``[start, end]`` form made ``zip(int, int)`` raise TypeError).

    Raises:
        ValueError: if neither ``args.chrom`` nor ``args.bed_file`` is set.
    """
    if args.chrom:
        return {args.chrom: ([int(args.start_pos)], [int(args.end_pos)])}
    if args.bed_file:
        # NOTE(review): assumes td.bed_file_to_dict returns (starts, ends) per
        # contig, which is how infer_vcf consumes it -- confirm in td.
        return td.bed_file_to_dict(args.bed_file)
    raise ValueError('What do you want to iterate over? Use arguments --bed_file or --chrom --start_pos --end_pos')


def infer_vcf(args):
    """Call variants by sliding a window over genomic intervals and running the model.

    Each window of ``args.window_size`` bases is sliced from the reference FASTA
    and turned into a tensor by ``td.make_calling_tensor``.  Windows for which it
    returns None are skipped.  Tensors are predicted on in batches of
    ``args.batch_size``.  The predictions are handed to ``predictions_to_variants``
    together with a writer for ``args.output_vcf``, whose header is copied from
    ``args.negative_vcf``.  The final, partially filled batch is predicted too.

    Intervals come from ``--chrom/--start_pos/--end_pos`` or from ``--bed_file``.
    Side effect: ``args.chrom`` is overwritten with each contig as it is scanned.
    Counters collected in ``stats`` are printed at the end.

    Raises:
        ValueError: if neither a region nor a BED file was given.
    """
    import contextlib

    # Resolve what to iterate over before loading anything expensive.
    intervals = _intervals_from_args(args)

    stats = Counter()
    model = models.load_model(args.weights_hd5, custom_objects=models.get_all_custom_objects(args.labels))
    batch_shape = (args.batch_size,) + defines.tensor_shape_from_args(args)

    with contextlib.ExitStack() as cleanup:
        vcf_reader = pysam.VariantFile(args.negative_vcf, 'r')
        cleanup.callback(vcf_reader.close)
        vcf_writer = pysam.VariantFile(args.output_vcf, 'w', header=vcf_reader.header)
        cleanup.callback(vcf_writer.close)  # closing the writer flushes the output VCF
        print('got vcfs.')
        reference = SeqIO.to_dict(SeqIO.parse(args.reference_fasta, "fasta"))
        print('Loaded reference FASTA:', args.reference_fasta)
        samfile = pysam.AlignmentFile(args.bam_file, "rb")
        cleanup.callback(samfile.close)
        print('got sam.')

        tensor_batch = np.zeros(batch_shape)
        gpos_batch = []
        record = None
        print(len(intervals), 'intervals to iterate over, contigs:', intervals.keys())
        start_time = time.time()
        for k in intervals:
            contig = reference[k]
            args.chrom = k
            for start, stop in zip(intervals[k][0], intervals[k][1]):
                for cur_pos in range(start, stop, args.window_size):
                    record = contig[cur_pos:cur_pos + args.window_size]
                    t = td.make_calling_tensor(args, samfile, record, cur_pos, stats)
                    if t is not None:
                        tensor_batch[stats['cur_tensor']] = t
                        gpos_batch.append((k, cur_pos, record))
                        stats['cur_tensor'] += 1
                    if stats['cur_tensor'] == args.batch_size:
                        predictions = model.predict(tensor_batch)  # predictions is a numpy array
                        predictions_to_variants(args, predictions, gpos_batch, tensor_batch, vcf_writer, record)
                        tensor_batch = np.zeros(batch_shape)
                        stats['cur_tensor'] = 0
                        stats['batches_processed'] += 1
                        gpos_batch = []
                        if stats['batches_processed'] % 100 == 0:
                            elapsed = time.time() - start_time
                            t_per_minute = stats['batches_processed'] * args.batch_size / (elapsed / 60)
                            print('At genomic position:', k, cur_pos, 'Tensors per minute:', t_per_minute, 'Batches processed:', stats['batches_processed'])
                            for s in stats:
                                print(s, 'has:', stats[s])

        # Previously the last partial batch was silently dropped.  Predict only
        # the filled rows so predictions, tensors and positions stay aligned.
        # NOTE(review): assumes the model accepts a batch smaller than
        # args.batch_size and that predictions_to_variants iterates over its
        # inputs rather than args.batch_size -- confirm.
        n_left = stats['cur_tensor']
        if n_left > 0:
            partial = tensor_batch[:n_left]
            predictions = model.predict(partial)
            predictions_to_variants(args, predictions, gpos_batch, partial, vcf_writer, record)
            stats['cur_tensor'] = 0
            stats['batches_processed'] += 1

    for s in stats:
        print(s, 'has:', stats[s])
def _tensor_paths(data_dir):
    """Return the path of every entry in ``data_dir``, sorted by entry name.

    ``os.path.join`` makes ``data_dir`` work with or without a trailing
    separator.  Plain string concatenation produced wrong paths without one.
    """
    return [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]


def infer_tensor(args):
    """Run the model over pre-computed tensor files and pass the predictions on.

    Every entry of ``args.data_dir`` (sorted by name) is opened with h5py, and
    the dataset named ``args.tensor_map`` is read as one tensor.  Its position
    is taken from the file name via ``td.position_string_from_tensor_name``,
    split on ``'_'``.  Tensors are predicted on in batches of
    ``args.batch_size``.  The results are handed to ``predictions_to_variants``
    with a writer for ``args.output_vcf``, whose header is copied from
    ``args.negative_vcf``.  The final, partially filled batch is predicted too.
    """
    import contextlib

    stats = Counter()
    model = models.load_model(args.weights_hd5, custom_objects=models.get_all_custom_objects(args.labels))
    with contextlib.ExitStack() as cleanup:
        vcf_reader = pysam.VariantFile(args.negative_vcf, 'r')
        cleanup.callback(vcf_reader.close)
        vcf_writer = pysam.VariantFile(args.output_vcf, 'w', header=vcf_reader.header)
        cleanup.callback(vcf_writer.close)  # closing the writer flushes the output VCF
        print('got vcfs.')
        tensor_paths = _tensor_paths(args.data_dir)
        print('found tensors: ', len(tensor_paths))
        tensor_batch = np.zeros((args.batch_size,) + defines.tensor_shape_from_args(args))
        gpos_batch = []
        for tp in tensor_paths:
            with h5py.File(tp, 'r') as hf:
                tensor_batch[stats['cur_tensor']] = np.array(hf.get(args.tensor_map))
            gpos_batch.append(td.position_string_from_tensor_name(tp).split('_'))
            stats['cur_tensor'] += 1
            if stats['cur_tensor'] == args.batch_size:
                # Evaluate the model
                predictions = model.predict(tensor_batch)  # predictions is a numpy array
                predictions_to_variants(args, predictions, gpos_batch, tensor_batch, vcf_writer)
                stats['cur_tensor'] = 0
                gpos_batch = []

        # Previously the last partial batch was silently dropped.  Predict only
        # the filled rows; unfilled rows could hold stale tensors from the
        # previous batch.
        # NOTE(review): assumes the model accepts a batch smaller than
        # args.batch_size -- confirm.
        n_left = stats['cur_tensor']
        if n_left > 0:
            partial = tensor_batch[:n_left]
            predictions = model.predict(partial)
            predictions_to_variants(args, predictions, gpos_batch, partial, vcf_writer)
            stats['cur_tensor'] = 0
def _make_read_tensor(args, v, samfile, ref_seq, ref_start, ref_end, stats):
    """Build the read tensor selected by ``args.tensor_map`` for variant ``v``.

    The result may be None; the caller substitutes a zero tensor in that case.

    Raises:
        ValueError: if ``args.tensor_map`` is not a known read tensor mapping.
    """
    if args.tensor_map == "read_tensor":
        return td.make_reference_and_reads_tensor(args, v, samfile, ref_seq, ref_start, stats)
    if args.tensor_map == "paired_reads":
        return td.make_paired_read_tensor(args, v, samfile, ref_seq, ref_start, ref_end, stats)
    # The original concatenated an undefined name here, raising NameError instead.
    raise ValueError("Unknown read tensor mapping." + args.tensor_map)


def _empty_batch(batch_size, input_tensors):
    """Return a zero array of shape ``(batch_size,) + shape`` for each tensor map."""
    return {tm: np.zeros((batch_size,) + shape) for tm, shape in input_tensors.items()}


def annotate_vcf_with_inference(args):
    """Score the variants of ``args.negative_vcf`` with one or more CNNs.

    For each architecture in ``args.architectures``:
      * a model is built with ``models.set_args_and_get_model_from_semantics``;
      * an INFO field named ``score_key_from_json(architecture)`` is added to the
        header if it is missing;
      * its input tensor shapes are recorded.

    Then each variant is processed.  When ``args.chrom`` is set, only variants
    fetched from ``args.chrom:start_pos-end_pos`` are used; otherwise every
    variant is.  Each variant becomes one row in every input-tensor batch.
    Full batches, and the final partial one, are passed to
    ``apply_cnns_to_batch`` together with a writer for ``args.output_vcf``.
    Iteration stops once (completed batches * ``args.batch_size``) exceeds
    ``args.samples``.

    Side effects: mutates ``args.chrom``, ``args.tensor_map`` and
    ``args.annotation_set``.  All opened files are closed on return.

    Raises:
        ValueError: if a read tensor map is neither "read_tensor" nor
            "paired_reads".
    """
    import contextlib

    cnns = {}
    stats = Counter()
    input_tensors = {}
    with contextlib.ExitStack() as cleanup:
        vcf_reader = pysam.VariantFile(args.negative_vcf, 'rb')
        cleanup.callback(vcf_reader.close)
        # A second, PyVCF view of the same file; pysam_variant_in_pyvcf uses it
        # for each pysam record.  The raw handle used to leak.
        pyvcf_vcf_reader = vcf.Reader(cleanup.enter_context(open(args.negative_vcf, 'rb')))
        for a in args.architectures:
            cnns[a] = models.set_args_and_get_model_from_semantics(args, a)
            print('Annotating with architecture:', a, 'sample name is', args.sample_name)
            score_key = score_key_from_json(a)
            if score_key not in vcf_reader.header.info:
                vcf_reader.header.info.add(score_key, '1', 'Float', 'Site-level score from Convolutional Neural Net named ' + a + '.')
            if defines.annotations_from_args(args) is not None:
                input_tensors[args.annotation_set] = (len(args.annotations),)
            input_tensors[args.tensor_map] = defines.tensor_shape_from_args(args)

        vcf_writer = pysam.VariantFile(args.output_vcf, 'w', header=vcf_reader.header)
        cleanup.callback(vcf_writer.close)  # closing the writer flushes the output VCF
        print('got vcfs. input tensor shape mapping:', input_tensors)
        reference = SeqIO.to_dict(SeqIO.parse(args.reference_fasta, "fasta"))
        print('got ref.')
        samfile = pysam.AlignmentFile(args.bam_file, "rb")
        cleanup.callback(samfile.close)
        print('got sam.')

        positions = []
        variant_batch = []
        batch = _empty_batch(args.batch_size, input_tensors)
        print('input tensors:', input_tensors)
        if args.chrom:
            print('iterate over region of vcf', args.chrom, args.start_pos, args.end_pos)
            variants = vcf_reader.fetch(args.chrom, args.start_pos, args.end_pos)
        else:
            print('iterate over vcf')
            variants = vcf_reader

        start_time = time.time()
        for variant in variants:
            _, ref_start, ref_end = get_variant_window(args, variant)
            # In case chrom isn't set on command line we need it to fetch reads.
            args.chrom = variant.contig
            record = reference[variant.contig][ref_start:ref_end]
            v = pysam_variant_in_pyvcf(variant, pyvcf_vcf_reader)
            for tm in batch:
                batch_key = tm + '_in_batch'
                if tm in defines.annotations:
                    args.annotation_set = tm
                    batch[tm][stats[batch_key]] = td.get_annotation_data(args, v, stats)
                    stats[batch_key] += 1
                if 'read' in tm:
                    args.tensor_map = tm
                    read_tensor = _make_read_tensor(args, v, samfile, record.seq, ref_start, ref_end, stats)
                    # Check before storing: writing None into the float batch is
                    # at best pointless and at worst a TypeError.
                    if read_tensor is None:
                        print('got empty', args.tensor_map, 'tensor at:', v)
                        read_tensor = np.zeros(input_tensors[tm])
                    batch[tm][stats[batch_key]] = read_tensor
                    stats[batch_key] += 1
                if 'reference' in tm:
                    args.tensor_map = tm
                    batch[tm][stats[batch_key]] = td.make_reference_tensor(args, record.seq)
                    stats[batch_key] += 1
            positions.append(variant.contig + '_' + str(variant.pos))
            variant_batch.append(variant)

            # Count variants directly.  The old check used the last tensor map's
            # counter, which was undefined after the loop for an empty VCF/region.
            if len(variant_batch) == args.batch_size:
                apply_cnns_to_batch(args, cnns, batch, positions, variant_batch, vcf_writer, stats)
                # Reset the batch
                positions = []
                variant_batch = []
                batch = _empty_batch(args.batch_size, input_tensors)
                for tm in batch:
                    stats[tm + '_in_batch'] = 0
                stats['batches processed'] += 1
                if stats['batches processed'] % 10 == 0:
                    elapsed = time.time() - start_time
                    v_per_minute = stats['batches processed'] * args.batch_size / (elapsed / 60)
                    print('Variants per minute:', v_per_minute, 'Batches:', stats['batches processed'], 'batches. Last variant:', variant)
                if stats['batches processed'] * args.batch_size > args.samples:
                    break

        if variant_batch:
            apply_cnns_to_batch(args, cnns, batch, positions, variant_batch, vcf_writer, stats)

    for s in stats:
        print(s, 'has:', stats[s])