def parse_idx(idx_summary, bam_regex):
    '''
    produces a table with summed read counts per sample: columns = 'sample' and 'mapped', rows = samples
    idx_summary: tab-separated summary file (with header line) of idxstats on bam files
    bam_regex: regular expression searched in the bam file name (column index 4), group 1 -> sample name
    returns: Table with one row per sample, ordered by sample name; rows not matching bam_regex are ignored
    '''

    # running total per sample name, starts at 0 for unseen samples
    mapped_per_sample = defaultdict(int)

    # walk through the idxstats summary and add up the counts (column index 2) of each sample
    stats = tf.readTable(idx_summary, sep='\t', header=True)
    for line_no in range(stats.rowNum()):
        hit = re.search(bam_regex, stats.get(line_no, 4))
        if not hit:
            # file name does not belong to any sample of interest
            continue
        name = hit.group(1)
        reads = int(stats.get(line_no, 2))
        mapped_per_sample[name] += reads

    # set up the table that is returned
    result = Table()
    result.addColumn(str, 'sample', None)
    result.addColumn(int, 'mapped', 0)

    # one row per sample, sorted by sample name
    for name in sorted(mapped_per_sample):
        result.addRow([name, mapped_per_sample[name]])

    return result
def selectRows(table, selectionFunction):
    '''
    builds a new table containing only the rows of 'table' accepted by 'selectionFunction'
    table: table to take the rows from
    selectionFunction: called as selectionFunction(table, rowindex); truthy result <-> the row is added to the new table
    returns: the new table, rows keep their original order
    '''

    # new table is initialised from 'table' (NOTE(review): Table(table) is expected to copy only the columns - confirm)
    filtered = Table(table)

    # check every row and skip those rejected by the selection function
    for idx in range(table.rowNum()):
        if not selectionFunction(table, idx):
            continue
        filtered.addRow(table.getRow(idx))

    return filtered
def parse_fastqc(fastqc_summary, raw_regex, trim_regex):
    '''
    produces a table with read counts per sample: columns = 'sample', 'raw' and 'trimmed', rows = samples
    fastqc_summary: tab-separated summary file of fastqc (with header line)
    raw_regex: regular expression to match raw fastq files, group 1 -> sample name
    trim_regex: regular expression to match trimmed fastq files, group 1 -> sample name
    returns: Table with one row per sample, ordered by sample name
    only rows with 'Total Sequences' in column index 0 are used; a file name matching both
    expressions is counted as raw
    '''

    # positions of the two counters kept per sample
    RAW, TRIMMED = 0, 1

    # sample name -> [raw count, trimmed count]
    per_sample = defaultdict(lambda: [0, 0])

    # collect the 'Total Sequences' entries of the fastqc summary
    stats = tf.readTable(fastqc_summary, sep='\t', header=True)
    for line_no in range(stats.rowNum()):
        if stats.get(line_no, 0) != 'Total Sequences':
            continue

        # file name (column index 2) decides to which sample and counter the value (column index 1) belongs
        fastq_name = stats.get(line_no, 2)
        total = int(stats.get(line_no, 1))
        as_raw = re.search(raw_regex, fastq_name)
        as_trimmed = re.search(trim_regex, fastq_name)

        # raw takes precedence over trimmed, unmatched files are ignored
        if as_raw:
            per_sample[as_raw.group(1)][RAW] += total
        elif as_trimmed:
            per_sample[as_trimmed.group(1)][TRIMMED] += total

    # set up the table that is returned
    result = Table()
    result.addColumn(str, 'sample', None)
    result.addColumn(int, 'raw', 0)
    result.addColumn(int, 'trimmed', 0)

    # one row per sample, sorted by sample name
    for name in sorted(per_sample):
        raw_total, trimmed_total = per_sample[name]
        result.addRow([name, raw_total, trimmed_total])

    return result
def parse_chrom_file(chrom_file, idx_summary, bam_regex):
    '''
    produces a table with read counts: one column per chromosome group, rows = samples
    chrom_file: tab-separated table with chromosome name (col 0) and organism/group name (col 1),
                first line is a header that may start with '#'
    idx_summary: tab-separated summary file (with header line) of idxstats on bam files
    bam_regex: regular expression searched in the bam file name (column index 4), group 1 -> sample name
    returns: Table with column 'sample' followed by the group names in sorted order;
             chromosomes missing in chrom_file and files not matching bam_regex are ignored
    '''

    # lookup chromosome name -> group name
    group_of_chrom = {}
    assignment = tf.readTable(chrom_file,
                              sep='\t',
                              header=True,
                              headerstart='#')
    for line_no in range(assignment.rowNum()):
        group = assignment.get(line_no, 1)
        group_of_chrom[assignment.get(line_no, 0)] = group

    # sample name -> (group name -> summed count)
    per_sample = defaultdict(lambda: defaultdict(int))

    # add up the counts (column index 2) of the idxstats summary per sample and group
    stats = tf.readTable(idx_summary, sep='\t', header=True)
    for line_no in range(stats.rowNum()):
        hit = re.search(bam_regex, stats.get(line_no, 4))
        chrom = stats.get(line_no, 0)
        if not hit or chrom not in group_of_chrom:
            continue
        name = hit.group(1)
        reads = int(stats.get(line_no, 2))
        per_sample[name][group_of_chrom[chrom]] += reads

    # set up the table that is returned: sample name + one column per distinct group
    result = Table()
    result.addColumn(str, 'sample', None)
    group_names = sorted(set(group_of_chrom.values()))
    for group in group_names:
        result.addColumn(int, group, 0)

    # one row per sample, sorted by sample name; groups without reads yield 0
    for name in sorted(per_sample):
        group_counts = per_sample[name]
        result.addRow([name] + [group_counts[group] for group in group_names])

    return result
def selectColumns(table, colList):
    '''
    builds a new table consisting of the columns of 'table' listed in colList, in the order given there
    table: table to take the columns from
    colList: columns to keep, passed on to getColumnType/getColumnName and to getRow(select_cols=...)
    returns: the new table with all rows of 'table' reduced to the selected columns
    '''

    projected = Table()

    # set up the columns with the type and name they have in 'table'
    for col in colList:
        col_type = table.getColumnType(col)
        col_name = table.getColumnName(col)
        projected.addColumn(col_type, col_name)

    # copy every row, restricted to the selected columns
    for idx in range(table.rowNum()):
        reduced_row = table.getRow(idx, select_cols=colList)
        projected.addRow(reduced_row)

    return projected
def groupBy(table,
            groupColumns,
            aggregateColumns,
            aggregateFunctions,
            append_aggregation_func_to_colname=False):
    '''
    groups the rows of 'table' by the values in 'groupColumns' and aggregates 'aggregateColumns' per group
    'groupColumns': list of column names or indices used for forming groups, aggregates all rows if set to []
    'aggregateColumns': list of column names that are aggregated into a single value per column and group
    'aggregateFunctions': list of strings giving aggregation functions, supported: 'min' and 'sum';
                          the i-th function is applied to the i-th aggregate column
    'append_aggregation_func_to_colname': if False the aggregate columns keep their names in the returned table,
                                          otherwise '_' + function name is appended to the column name
    a single int/str is accepted for 'groupColumns' and 'aggregateColumns', a single str for 'aggregateFunctions'
    returns: new Table with the group columns followed by the aggregate columns, one row per group, sorted by group
    raises: ValueError if an aggregation function name is unknown
    '''

    ## check input

    # single elements are wrapped into lists so that the code below can always iterate
    if isinstance(groupColumns, (int, str)):
        groupColumns = [groupColumns]
    if isinstance(aggregateColumns, (int, str)):
        aggregateColumns = [aggregateColumns]
    if isinstance(aggregateFunctions, str):
        aggregateFunctions = [aggregateFunctions]

    # transform aggregateFunctions given as strings to function pointers
    aggregate_fps = []
    for af_name in aggregateFunctions:
        if af_name == 'min':
            aggregate_fps.append(_min)
        elif af_name == 'sum':
            aggregate_fps.append(_sum)
        else:
            raise ValueError('Unknown aggregation function: ' + str(af_name))

    ## mapping of groups to row indices

    # map group as tuple of values -> list of row indices as 0-based ints
    group_to_rows = {}
    for rowInd in range(table.rowNum()):
        group = tuple(table.get(rowInd, col) for col in groupColumns)
        group_to_rows.setdefault(group, []).append(rowInd)

    ## set up return table

    groupTable = Table()
    # add columns defining the groups
    for gc in groupColumns:
        groupTable.addColumn(table.getColumnType(gc), table.getColumnName(gc))
    # add columns for aggregated values
    for ac, af in zip(aggregateColumns, aggregateFunctions):
        colName = table.getColumnName(ac)
        # identity check on purpose: only the literal False keeps the plain column name
        if append_aggregation_func_to_colname is not False:
            colName = colName + '_' + af
        groupTable.addColumn(table.getColumnType(ac), colName)

    ##  perform aggregations

    # each aggregation function is called with the table, the row indices of the group and the column
    for cur_group, cur_rowlist in sorted(group_to_rows.items()):
        row = list(cur_group)
        for ac, af in zip(aggregateColumns, aggregate_fps):
            row.append(af(table, cur_rowlist, ac))
        groupTable.addRow(row)

    return groupTable
def joinTables(tab1, tab2, joinCols, joinType='inner'):
    '''
    joins the two tables tab1 and tab2 on equal values in the join columns
    joinCols: list of tuples (col_1, col_2) to compare when joining, col_i: column position or name from tab_i
    joinType: 'inner', 'leftouter' or 'fullouter'; any other value behaves like 'inner'
    returns: new Table with all columns of tab1 followed by the columns of tab2 that are no join columns;
             cells without join partner are filled with None
    '''

    # TODO: check types of columns

    # TODO: implement outer join types: rightouter

    ## layout of the joined table -> column names and types combined

    left_keys, right_keys = zip(*joinCols)
    result = Table()

    # all columns of the first table, unnamed columns stay unnamed
    for pos, name in enumerate(tab1.getColumnNames(noneHandling=False)):
        if name is not None:
            result.addColumn(tab1.getColumnType(pos), name)
        else:
            result.addColumn(tab1.getColumnType(pos))

    # columns of the second table except its join columns (given by position or by name)
    # NOTE(review): getColumnNames is called without noneHandling here, unlike for tab1 - confirm this is intended
    excluded = set(right_keys)
    kept_positions = []
    for pos, name in enumerate(tab2.getColumnNames()):
        if pos in excluded or name in excluded:
            continue
        if name is not None:
            result.addColumn(tab2.getColumnType(pos), name)
        else:
            result.addColumn(tab2.getColumnType(pos))
        kept_positions.append(pos)

    ## lookup structure for the rows of the second table

    # tuple of join column values -> row indices of tab2 having these values
    rows_by_key = {}
    for right_row in range(tab2.rowNum()):
        key = tuple(tab2.get(right_row, col) for col in right_keys)
        rows_by_key.setdefault(key, []).append(right_row)
    # keys of tab2 that found a partner in tab1 -> needed for the full outer join
    matched = set()

    ## join rows and fill joined table

    keep_unmatched_left = joinType in ('leftouter', 'fullouter')
    for left_row in range(tab1.rowNum()):
        key = tuple(tab1.get(left_row, col) for col in left_keys)
        partners = rows_by_key.get(key)

        if partners is not None:
            # inner join: one output row for every pair of matching rows
            for right_row in partners:
                combined = tab1.getRow(left_row)
                combined.extend(
                    tab2.get(right_row, pos) for pos in kept_positions)
                result.addRow(combined)
            matched.add(key)
        elif keep_unmatched_left:
            # left/full outer join: row of tab1 without partner, tab2 part is None
            padded = tab1.getRow(left_row)
            padded.extend([None] * len(kept_positions))
            result.addRow(padded)

    # full outer join: rows of tab2 without partner in tab1
    if joinType == 'fullouter':
        for key, partners in rows_by_key.items():
            if key in matched:
                continue
            for right_row in partners:
                # tab1 part is None ...
                lonely = [None] * tab1.colNum()
                # ... except for the join columns, which appear only once in the output
                # and therefore take the values of the tab2 join columns
                for left_col, right_col in joinCols:
                    left_pos = tab1._check_table_access_column(
                        left_col, 'join')
                    lonely[left_pos] = tab2.get(right_row, right_col)
                # remaining data of tab2
                for pos in kept_positions:
                    lonely.append(tab2.get(right_row, pos))
                result.addRow(lonely)

    return result
def readTable(fileName,
              sep='\t',
              header=True,
              headerstart='',
              colsToRead=None,
              tableEnd=None,
              comment=None,
              split_quoted_cell=True,
              quote_symbol='"'):
    '''
    function to read a table from a file, all columns are created with type str
    fileName: name of the file to read
    sep: character separating the columns of the table
    header: indicating if the table has a header with column names (first line)
    headerstart: sign at the beginning the header of the table (first line) that will be removed;
                 it is matched literally, i.e. characters like '*' or '(' have no special meaning
    colsToRead: a list (or other sequence) of column positions to read, giving also their order;
                if the table has a header, column names can be given as well; None reads all columns;
                the passed sequence is not modified
    tableEnd: regex for recognizing the end of the table, reading stops at the first line matching it
    comment: sign at linestart indicating a comment -> is not read
    split_quoted_cell: option for parsing csv files with cells containing ',' symbols which should not be interpreted as cell separators
    quote_symbol: symbol to recognize quoted cells if split_quoted_cell = False
    returns: Table with one row per data line
    raises: ValueError if the table has a header and an entry of colsToRead is neither int nor str
            or is a column name that is not part of the header
    '''

    table = Table()

    with open(fileName, 'r') as reader:

        # column names in header ('== True' kept on purpose so that other truthy values are still treated as 'no header')
        if (header == True):
            headerline = reader.readline().strip('\n')
            # remove symbol marking the headerline; escaped so that it is taken literally and not as a regex
            headerline = re.sub('^' + re.escape(headerstart), '', headerline)
            colNames = _split_row_into_cells(headerline, sep,
                                             split_quoted_cell, quote_symbol)

            # read all columns
            if colsToRead is None:
                for colName in colNames:
                    # make all columns of type string
                    table.addColumn(str, colName)

            # read specific columns given as column position or column name
            else:
                # work on a copy because names get replaced by positions; list() also accepts tuples
                colsToRead = list(colsToRead)
                for (pos, readCol) in enumerate(colsToRead):
                    # column position given -> get column name from header
                    if (isinstance(readCol, int)):
                        table.addColumn(str, colNames[readCol])
                    # column name given -> check with header and extract column position
                    elif (isinstance(readCol, str)):
                        if (readCol in colNames):
                            table.addColumn(str, readCol)
                            colsToRead[pos] = colNames.index(readCol)
                        else:
                            raise ValueError(
                                'Cannot read column ' + readCol +
                                ' because it is not part of the table header!')
                    else:
                        raise ValueError(
                            'Cannot read column ' + repr(readCol) +
                            ' because it is no integer or string!')

        # no header -> create nameless columns
        else:
            # read all columns -> the number of columns is taken from the first non-comment line
            if colsToRead is None:
                firstline = reader.readline().strip('\n')
                while (comment is not None and firstline.startswith(comment)):
                    firstline = reader.readline().strip('\n')
                firstRow = _split_row_into_cells(firstline, sep,
                                                 split_quoted_cell,
                                                 quote_symbol)
                for _ in range(0, len(firstRow)):
                    table.addColumn(str)
                table.addRow(firstRow)
            # read specific columns
            else:
                for _ in range(0, len(colsToRead)):
                    table.addColumn(str)

        # remaining lines of the file are the data rows
        for line in reader:
            line = line.strip('\n')

            # skip comment lines
            if (comment is not None and line.startswith(comment)):
                continue

            # check for table end
            if (tableEnd is not None and re.search(tableEnd, line)):
                break

            # add current row to table
            rowContent = _split_row_into_cells(line, sep, split_quoted_cell,
                                               quote_symbol)
            # add all columns of the row
            if (colsToRead is None):
                table.addRow(rowContent)
            # add only specific columns of the row in a given order
            else:
                table.addRow([rowContent[pos] for pos in colsToRead])

    return table