def test_read_table_columns(mocker):
    """
    Parse small comma separated tables through a mocked open() and check
    the returned column dictionary and label list
    """
    scenarios = [
        # (file contents, expected columns, expected labels)

        # as many headers as fields
        ('a,b,c,d\ne,f,g,h\ni,j,k,l\nm,n,o,p\n',
         {'a': ['e', 'i', 'm'],
          'b': ['f', 'j', 'n'],
          'c': ['g', 'k', 'o'],
          'd': ['h', 'l', 'p']},
         ['a', 'b', 'c', 'd']),

        # fewer headers than rest, extra field in each row is dropped
        ('a,b,c\ne,f,g,h\ni,j,k,l\nm,n,o,p\n',
         {'a': ['e', 'i', 'm'],
          'b': ['f', 'j', 'n'],
          'c': ['g', 'k', 'o']},
         ['a', 'b', 'c']),

        # fewer columns than headers, unmatched header keeps an empty list
        ('a,b,c,d\ne,f,g\ni,j,k\nm,n,o\n',
         {'a': ['e', 'i', 'm'],
          'b': ['f', 'j', 'n'],
          'c': ['g', 'k', 'o'],
          'd': []},
         ['a', 'b', 'c', 'd']),
    ]

    for contents, expected_columns, expected_labels in scenarios:
        opener = mocker.patch('misc.read_table.open',
                              return_value=StringIO(contents))
        columns, header = read_table.read_table_columns('mocked', ',')

        opener.assert_called_with('mocked', 'r')
        assert columns == expected_columns
        assert header == expected_labels
Esempio n. 2
0
def get_range_seqs(strains, chrm, start, end, tag, gp_dir='../'):
    """
    Pull out, for each strain, the part of its chromosome sequence that
    corresponds to the reference positions start through end

    strains: iterable of (strain name, directory) pairs; the directory is
        used as a prefix of the strain's fasta file name
    chrm: chromosome name used to build the file names
    start, end: reference positions; they are looked up as strings in the
        'ps_ref' column
    tag: analysis tag, names the directory holding the site summaries
    gp_dir: prefix for the alignment directory, only used when a strain has
        no site summary file

    Returns a dict mapping strain name to a tuple of
    (sequence slice, start index in strain, end index in strain)
    """
    # TODO this shouldn't actually be dependent on tag

    seqs_by_strain = {}
    for strain, strain_dir in strains:
        print(strain)
        fasta_fn = ''.join(
            [strain_dir, strain, '_chr', chrm, gp.fasta_suffix])
        chrm_seq = read_fasta.read_fasta(fasta_fn)[1][0]

        try:
            summary_fn = ''.join([gp.analysis_out_dir_absolute, tag, '/',
                                  'site_summaries/predictions_', strain,
                                  '_chr', chrm, '_site_summary.txt.gz'])
            site_table, _ = read_table.read_table_columns(summary_fn, '\t')
        except FileNotFoundError:
            # for par reference which doesn't have site summary file
            align_fn = ''.join([gp_dir, gp.alignments_dir,
                                '_'.join(gp.alignment_ref_order),
                                '_chr', chrm, '_mafft',
                                gp.alignment_suffix])
            site_table = get_inds_from_alignment(align_fn, True)

        # reference position (as read, a string key) -> strain position
        ref_to_strain = dict(zip(site_table['ps_ref'],
                                 site_table['ps_strain']))

        # round inwards so the slice stays inside the requested range
        first = int(math.ceil(float(ref_to_strain[str(start)])))
        last = int(math.floor(float(ref_to_strain[str(end)])))

        seqs_by_strain[strain] = (chrm_seq[first:last + 1], first, last)

    return seqs_by_strain
def test_read_table_columns_empty(mocker):
    """
    Read an empty table through mocked open() and gzip.open()

    Checks that an empty input gives a single column named '' with no
    entries, and that the opener is chosen from the file name: open() with
    mode 'r' for 'mocked', gzip.open() with mode 'rt' for 'mocked.gz'
    """
    # plain file name: open() is used, gzip.open() is left alone
    table = StringIO('')
    mocked_file = mocker.patch('misc.read_table.open', return_value=table)
    mocked_gz = mocker.patch('misc.read_table.gzip.open', return_value=table)

    d, labels = read_table.read_table_columns('mocked', '\t')
    mocked_file.assert_called_with('mocked', 'r')
    mocked_gz.assert_not_called()
    assert d == {'': []}
    assert labels == ['']

    # .gz file name: gzip.open() is used, open() is left alone
    # patch again so both mocks start with a clean call history
    table = StringIO('')
    mocked_file = mocker.patch('misc.read_table.open', return_value=table)
    mocked_gz = mocker.patch('misc.read_table.gzip.open', return_value=table)

    d, labels = read_table.read_table_columns('mocked.gz', '\t')
    mocked_gz.assert_called_with('mocked.gz', 'rt')
    mocked_file.assert_not_called()
    assert d == {'': []}
    assert labels == ['']
    def __init__(self,
                 labeled_file: str,
                 chromosome: str,
                 known_states: List[str]):
        '''
        Read in labeled file and store resulting table and labels

        labeled_file: name of the tab separated file to read
        chromosome: only rows whose 'chromosome' column equals this value
            are requested from read_table_columns
        known_states: state names; one zero filled column is added per
            label prefix and state

        The table is grouped by the 'strain' column, so self.data maps each
        strain to its own column dictionary. Every strain gets zero filled
        columns '<prefix>_<state>' and 'count_<symbol>', and the same names
        are appended to self.labels.

        Raises ValueError if the first label is not 'region_id'
        '''
        self.info_string_symbols = list('.-_npbcxNPBCX')

        self.label_prefixes = ['match_nongap',
                               'num_sites_nongap',
                               'match_hmm',
                               'match_nonmask',
                               'num_sites_nonmask']

        self.data, self.labels = read_table.read_table_columns(
            labeled_file,
            sep='\t',
            group_by='strain',
            chromosome=chromosome)

        if self.labels[0] != 'region_id':
            err = 'Unexpected labeled format'
            # not inside an except block, so there is no traceback to log;
            # log.exception here would append a bogus 'NoneType: None'
            log.error(err)
            raise ValueError(err)

        # the strain key is not needed, only its column dictionary
        for data in self.data.values():
            n = len(data['region_id'])

            # one placeholder value per region, filled in later
            for s in known_states:
                for lbl in self.label_prefixes:
                    data[f'{lbl}_{s}'] = [0] * n

            for s in self.info_string_symbols:
                data['count_' + s] = [0] * n

        self.labels += [f'{lbl}_{st}' for lbl in self.label_prefixes
                        for st in known_states]
        self.labels += ['count_' + x for x in self.info_string_symbols]
def test_read_table_columns_partition(mocker):
    """
    Check the group_by option of read_table_columns, alone and combined
    with a column filter, using a mocked open()
    """
    unique_rows = 'a,b,c,d\ne,f,g,h\ni,j,k,l\nm,n,o,p\n'
    repeated_key = 'a,b,c,d\ne,f,g,h\ne,g,h,i\ni,j,k,l\nm,n,o,p\n'

    group_i = {'a': ['i'], 'b': ['j'], 'c': ['k'], 'd': ['l']}

    scenarios = [
        # (file contents, keyword arguments, expected result)

        # non existent column, no grouping
        (unique_rows,
         {'group_by': 'nothing'},
         {'a': ['e', 'i', 'm'],
          'b': ['f', 'j', 'n'],
          'c': ['g', 'k', 'o'],
          'd': ['h', 'l', 'p']}),

        # group by a
        (repeated_key,
         {'group_by': 'a'},
         {'e': {'a': ['e', 'e'],
                'b': ['f', 'g'],
                'c': ['g', 'h'],
                'd': ['h', 'i']},
          'i': group_i,
          'm': {'a': ['m'], 'b': ['n'], 'c': ['o'], 'd': ['p']}}),

        # group by a, and filter on b
        (repeated_key,
         {'group_by': 'a', 'b': 'j'},
         {'i': group_i}),
    ]

    for contents, options, expected in scenarios:
        opener = mocker.patch('misc.read_table.open',
                              return_value=StringIO(contents))
        grouped, header = read_table.read_table_columns('mocked',
                                                        ',',
                                                        **options)

        opener.assert_called_with('mocked', 'r')
        assert grouped == expected
        assert header == ['a', 'b', 'c', 'd']
def test_read_table_columns_filter(mocker):
    """
    Check the keyword filters of read_table_columns: unknown column, no
    matching row, one matching row and two filters at once
    """
    unique_rows = 'a,b,c,d\ne,f,g,h\ni,j,k,l\nm,n,o,p\n'
    repeated_rows = ('a,b,c,d\n'
                     'e,j,k,l\n'
                     'e,f,g,h\n'
                     'i,j,k,l\n'
                     'e,j,l,m\n'
                     'm,n,o,p\n')

    scenarios = [
        # (file contents, filter keywords, expected columns)

        # non existent key, no filter
        (unique_rows,
         {'z': 'nothing'},
         {'a': ['e', 'i', 'm'],
          'b': ['f', 'j', 'n'],
          'c': ['g', 'k', 'o'],
          'd': ['h', 'l', 'p']}),

        # filter no matches
        (unique_rows,
         {'b': 'o'},
         {'a': [], 'b': [], 'c': [], 'd': []}),

        # filter single match, on first column
        (unique_rows,
         {'a': 'e'},
         {'a': ['e'], 'b': ['f'], 'c': ['g'], 'd': ['h']}),

        # filter single match, on second column
        (unique_rows,
         {'b': 'j'},
         {'a': ['i'], 'b': ['j'], 'c': ['k'], 'd': ['l']}),

        # multiple filters, a row has to pass all of them
        (repeated_rows,
         {'b': 'j', 'a': 'e'},
         {'a': ['e', 'e'],
          'b': ['j', 'j'],
          'c': ['k', 'l'],
          'd': ['l', 'm']}),
    ]

    for contents, filters, expected in scenarios:
        opener = mocker.patch('misc.read_table.open',
                              return_value=StringIO(contents))
        columns, header = read_table.read_table_columns('mocked',
                                                        ',',
                                                        **filters)

        opener.assert_called_with('mocked', 'r')
        assert columns == expected
        assert header == ['a', 'b', 'c', 'd']
    get_ref_gene_seq(gene, ref_gene_coords_fn, ref_seq_fn)
# write the reference gene sequence to a one line query file; the with
# block closes the file even if the write raises
query_fn = gene + '.txt'
with open(query_fn, 'w') as f:
    f.write(ref_gene_seq + '\n')

print('getting gene sequences from all strains')
gp_dir = '../'
s = align_helpers.get_strains(align_helpers.flatten(gp.non_ref_dirs.values()))
# per strain lookup between the 'ps_ref' and 'ps_strain' columns, one
# dictionary for each direction
ref_ind_to_strain_ind = {}
strain_ind_to_ref_ind = {}
for strain, d in s:
    print('*', strain)
    sys.stdout.flush()
    t, labels = read_table.read_table_columns(
        gp.analysis_out_dir_absolute + tag + '/' +
        'site_summaries/predictions_' + strain + '_chr' + chrm +
        '_site_summary.txt.gz', '\t')
    ref_ind_to_strain_ind[strain] = dict(zip(t['ps_ref'], t['ps_strain']))
    strain_ind_to_ref_ind[strain] = dict(zip(t['ps_strain'], t['ps_ref']))
# for par reference which doesn't have site summary file
# its index columns are taken from the reference alignment instead
align_fn = gp_dir + gp.alignments_dir + \
           '_'.join(gp.alignment_ref_order) + '_chr' + chrm + \
           '_mafft' + gp.alignment_suffix
t = get_inds_from_alignment(align_fn, True)
other_ref_strain = gp.ref_fn_prefix[gp.alignment_ref_order[1]]
ref_ind_to_strain_ind[other_ref_strain] = dict(zip(t['ps_ref'],
                                                   t['ps_strain']))
strain_ind_to_ref_ind[other_ref_strain] = dict(zip(t['ps_strain'],
                                                   t['ps_ref']))
# treat the second reference like any other strain from here on
s.append((other_ref_strain, gp.ref_dir[gp.alignment_ref_order[1]]))
strain_gene_seqs = get_gene_seqs(query_fn, s, chrm, ref_seq_fn, ref_start,
def main():
    """
    Compute per region statistics for the regions of one species and write
    them out together with the region alignments

    Command line: sys.argv[1] is the task index, used as the index into
    args['states'] to pick the species; the remaining arguments are parsed
    by read_args.process_predict_args.

    For every chromosome in gp.chrms the labeled blocks file is read,
    grouped by strain. The positions file is read line by line; each line
    holds strain, chromosome and the positions used by the HMM, tab
    separated. For each region of each strain the following are recorded:
    matches and site counts against every known state (HMM sites, non gap
    sites, non masked sites) and the count of each info string symbol.

    Output, all below <analysis dir><tag>:
    - regions/<species><fasta suffix>.gz: alignment slice and info string
      of every region
    - regions/<species>.pkl: pickled dict from numeric region id to the
      value of region_writer.tell() just before the region was written
    - blocks_<species>_<tag>_quality.txt: input table plus the new columns,
      sorted by region id within each chromosome
    """

    args = read_args.process_predict_args(sys.argv[2:])

    task_ind = int(sys.argv[1])
    species_ind = task_ind

    species_from = args['states'][species_ind]

    base_dir = gp.analysis_out_dir_absolute + args['tag']

    regions_dir = f'{base_dir}/regions/'
    # exist_ok tolerates the directory being created by another process
    # between an existence check and the creation
    os.makedirs(regions_dir, exist_ok=True)

    quality_writer = None
    positions = gzip.open(f'{base_dir}/positions_{args["tag"]}.txt.gz', 'rt')
    line_number = 0

    region_writer = gzip.open(
        f'{regions_dir}{species_from}{gp.fasta_suffix}.gz', 'wt')
    region_index = {}

    # symbols counted in each info string; defined once up front so it is
    # also available for a chromosome without any labeled strain
    info_string_symbols = list('.-_npbcxNPBCX')

    for chrm in gp.chrms:
        # region_id strain chromosome predicted_species start end num_non_gap
        regions_chrm, labels = read_table.read_table_columns(
            f'{base_dir}/blocks_{species_from}_{args["tag"]}_labeled.txt',
            '\t',
            group_by='strain',
            chromosome=chrm
        )

        # add zero filled columns, one value per region, to be filled below
        for strain in regions_chrm:
            n = len(regions_chrm[strain]['region_id'])

            for s in args['known_states']:
                regions_chrm[strain]['match_nongap_' + s] = [0] * n
                regions_chrm[strain]['num_sites_nongap_' + s] = [0] * n
                regions_chrm[strain]['match_hmm_' + s] = [0] * n
                regions_chrm[strain]['match_nonmask_' + s] = [0] * n
                regions_chrm[strain]['num_sites_nonmask_' + s] = [0] * n

            for s in info_string_symbols:
                regions_chrm[strain]['count_' + s] = [0] * n

        # get masked sites for all references, not just the current
        # species_from we're considering regions from
        masked_sites_refs = {}
        for s, state in enumerate(args['known_states']):
            masked_sites_refs[s] = \
                convert_intervals_to_sites(
                    read_masked_intervals(
                        f'{gp.mask_dir}{state}'
                        f'_chr{chrm}_intervals.txt'))

        # loop through chromosomes and strains, followed by species of
        # introgression so that we only have to read each alignment in once
        # move to last read chromosome
        positions.seek(line_number)
        line = positions.readline()
        while line != '':
            line = line.split('\t')

            # a line of another chromosome ends this chromosome; it is read
            # again after the seek at the start of the next iteration
            current_chrm = line[1]
            if current_chrm != chrm:
                break

            strain = line[0]
            if strain not in regions_chrm:
                # record current position in case need to re read line
                line_number = positions.tell()
                line = positions.readline()
                continue

            print(strain, chrm)

            # indices of alignment columns used by HMM
            ps = np.array([int(x) for x in line[2:]])

            headers, seqs = read_fasta.read_fasta(
                args['setup_args']['alignments_directory'] + \
                '_'.join(args['known_states'])
                + f'_{strain}_chr{chrm}_mafft{gp.alignment_suffix}')

            # to go from index in reference seq to index in alignment
            ind_align = []
            for seq in seqs:
                ind_align.append(index_alignment_by_reference(seq))

            masked_sites = convert_intervals_to_sites(
                read_masked_intervals(
                    f'{gp.mask_dir}{strain}_chr{chrm}_intervals.txt'))

            masked_sites_ind_align = []
            for s in range(len(args['known_states'])):
                masked_sites_ind_align.append(
                    ind_align[s][masked_sites_refs[s]])

            # add in sequence of query strain
            masked_sites_ind_align.append(
                ind_align[-1][masked_sites])

            # convert position indices from indices in master reference to
            # indices in alignment
            ps_ind_align = ind_align[0][ps]

            # loop through all regions for the specified chromosome and the
            # current strain
            for i in range(len(regions_chrm[strain]['region_id'])):
                r_id = regions_chrm[strain]['region_id'][i]
                start = regions_chrm[strain]['start'][i]
                end = regions_chrm[strain]['end'][i]

                # calculate:
                # - identity with each reference
                # - fraction of region that is gapped/masked

                # index of start and end of region in aligned sequences
                slice_start = ind_align[0][int(start)]
                slice_end = ind_align[0][int(end)]
                assert slice_start in ps_ind_align, \
                    f'{slice_start} {start} {r_id}'
                assert slice_end in ps_ind_align, \
                    f'{slice_end} {end} {r_id}'

                seqx = seqs[-1][slice_start:slice_end + 1]
                len_seqx = slice_end - slice_start + 1
                len_states = len(args['known_states'])

                # . = all match
                # - = gap in one or more sequences
                # p = matches predicted reference

                info = {'gap_any_flag': np.zeros((len_seqx), bool),
                        'mask_any_flag': np.zeros((len_seqx), bool),
                        'unseq_any_flag': np.zeros((len_seqx), bool),
                        'hmm_flag': np.zeros((len_seqx), bool),
                        'gap_flag': np.zeros((len_seqx, len_states), bool),
                        'mask_flag': np.zeros((len_seqx, len_states), bool),
                        'unseq_flag': np.zeros((len_seqx, len_states), bool),
                        'match_flag': np.zeros((len_seqx, len_states), bool)}

                for sj, statej in enumerate(args['known_states']):
                    seqj = seqs[sj][slice_start:slice_end+1]

                    # only alignment columns used by HMM (polymorphic, no
                    # gaps in any strain)
                    total_match_hmm, total_sites_hmm, infoj = \
                        seq_id_hmm(seqj, seqx, slice_start, ps_ind_align)

                    if statej == species_from \
                            or species_ind >= len(args['known_states']):
                        regions_chrm[strain]['num_sites_hmm'][i] = \
                            total_sites_hmm

                    # only write once, the first index
                    if sj == 0:
                        info['hmm_flag'] = infoj['hmm_flag']

                    info['gap_any_flag'] = np.logical_or(
                        info['gap_any_flag'], infoj['gap_flag'])
                    info['unseq_any_flag'] = np.logical_or(
                        info['unseq_any_flag'], infoj['unseq_flag'])
                    info['gap_flag'][:, sj] = infoj['gap_flag']
                    info['unseq_flag'][:, sj] = infoj['unseq_flag']
                    info['match_flag'][:, sj] = infoj['match']

                    regions_chrm[strain][f'match_hmm_{statej}'][i] = \
                        total_match_hmm

                    # all alignment columns, excluding ones with gaps in
                    # these two sequences
                    total_match_nongap, total_sites_nongap = \
                        seq_functions.seq_id(seqj, seqx)

                    regions_chrm[strain][f'match_nongap_{statej}'][i] =\
                        total_match_nongap
                    regions_chrm[strain][f'num_sites_nongap_{statej}'][i] =\
                        total_sites_nongap

                    # all alignment columns, excluding ones with gaps or
                    # masked bases or unsequenced in *these two sequences*
                    total_match_nonmask, total_sites_nonmask, infoj = \
                        seq_id_unmasked(seqj, seqx, slice_start,
                                        masked_sites_ind_align[sj],
                                        masked_sites_ind_align[-1])

                    info['mask_any_flag'] = np.logical_or(
                        info['mask_any_flag'], infoj['mask_flag'])
                    info['mask_flag'][:, sj] = infoj['mask_flag']

                    regions_chrm[strain][f'match_nonmask_{statej}'][i] = \
                        total_match_nonmask
                    regions_chrm[strain][f'num_sites_nonmask_{statej}'][i] = \
                        total_sites_nonmask

                # remember where this region starts in the regions file;
                # region ids carry a one character prefix before the number
                region_index[int(r_id[1:])] = region_writer.tell()
                region_writer.write(f'#{r_id}\n')
                names = args['known_states'] + [strain]
                for sj in range(len(names)):
                    # write sequence to region alignment file, along with
                    # start and end coordinates
                    startj = bisect.bisect_left(ind_align[sj], slice_start)
                    endj = bisect.bisect_left(ind_align[sj], slice_end)

                    region_writer.write(f'> {names[sj]} {startj} {endj}\n')
                    region_writer.write(
                        ''.join(seqs[sj][slice_start:slice_end+1]) + '\n')

                # also write string with info about each site
                info_string = make_info_string(info, 0, species_ind)
                region_writer.write('> info\n')
                region_writer.write(info_string + '\n')

                # TODO this can be made faster with numpy
                # and keep track of each symbol count
                for sym in info_string_symbols:
                    regions_chrm[strain]['count_' + sym][i] = \
                        info_string.count(sym)

            # record current position in case need to re read line
            line_number = positions.tell()
            line = positions.readline()
            sys.stdout.flush()

        # labels is read fresh for every chromosome, so the added column
        # names have to be appended each time
        labels += ['match_nongap_' + x for x in args['known_states']]
        labels += ['num_sites_nongap_' + x for x in args['known_states']]
        labels += ['match_hmm_' + x for x in args['known_states']]
        labels += ['match_nonmask_' + x for x in args['known_states']]
        labels += ['num_sites_nonmask_' + x for x in args['known_states']]
        labels += ['count_' + x for x in info_string_symbols]

        assert labels[0] == 'region_id', 'Unexpected labeled format'

        # write on first execution
        if quality_writer is None:
            quality_writer = open(f'{base_dir}/blocks_{species_from}'
                                  f'_{args["tag"]}_quality.txt', 'w')

            quality_writer.write('\t'.join(labels) + '\n')

        # reorganize output as list of tuples ordered by label
        output = []
        strains = list(regions_chrm.keys())
        for strain in strains:
            # pop to limit memory usage
            d = regions_chrm.pop(strain)
            output += list(zip(*[d[l] for l in labels]))

        # sort by region id (index 0, remove r)
        for entry in sorted(output, key=lambda e: int(e[0][1:])):
            quality_writer.write('\t'.join([str(e) for e in entry]) + '\n')

    # quality_writer is only opened inside the loop, so it is still None
    # when gp.chrms is empty
    if quality_writer is not None:
        quality_writer.close()
    positions.close()
    region_writer.close()
    with open(f'{regions_dir}{species_from}.pkl', 'wb') as index:
        pickle.dump(region_index, index)