Example #1
def get_matrix_topics_for_dec(self):
    from sklearn.feature_extraction.text import TfidfTransformer
    # Build the term-frequency matrix and its aligned topic labels.
    matrix, topics = self.get_matrix_topics(using='tf')
    topics = np.array(au.reindex(topics))
    # Re-weight counts with sublinear TF-IDF and unit L2 rows.
    matrix = TfidfTransformer(norm='l2', sublinear_tf=True).fit_transform(matrix)
    matrix = matrix.astype(np.float32)
    print(matrix.shape, matrix.dtype, matrix.size)
    # Densify and scale by sqrt(n_features): with unit L2 rows this makes
    # the average squared feature value 1.
    matrix = np.asarray(matrix.todense()) * np.sqrt(matrix.shape[1])
    print('todense succeeded')
    # Shuffle rows and labels with the same permutation.
    p = np.random.permutation(matrix.shape[0])
    matrix = matrix[p]
    topics = topics[p]
    print('permutation finished')
    assert matrix.shape[0] == topics.shape[0]
    return matrix, topics
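
The method above relies on project-local helpers (get_matrix_topics, au.reindex), so the following is a minimal standalone sketch of the same preprocessing, assuming a plain list of documents in place of those helpers:

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

# Hypothetical toy corpus standing in for get_matrix_topics(using='tf').
docs = ['deep embedded clustering', 'topic models cluster text',
        'deep clustering of text']
counts = CountVectorizer().fit_transform(docs)

# Sublinear TF-IDF with unit L2 rows, as in the method above.
matrix = TfidfTransformer(norm='l2', sublinear_tf=True).fit_transform(counts)

# Densify and scale by sqrt(n_features): with unit L2 rows, the average
# squared feature value becomes 1.
matrix = np.asarray(matrix.todense()) * np.sqrt(matrix.shape[1])
matrix = matrix.astype(np.float32)

# Shuffle rows and aligned labels with one shared permutation.
labels = np.arange(len(docs))
p = np.random.permutation(matrix.shape[0])
matrix, labels = matrix[p], labels[p]
assert matrix.shape[0] == labels.shape[0]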
Example #2
import os
import re

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.preprocessing import normalize
from torch.utils.data import Dataset  # assumed: the PyTorch Dataset API

# generate_k_mer_corpus and ensure_gene_length are project-local helpers
# that are not shown here.


class GenomeDataset_v2(Dataset):
    '''
    Metagenomics dataset for reading simulated data in FASTA format (.fna).
    '''
    HASH_PATTERN = r'\([a-f0-9]{40}\)'

    def __init__(self,
                 fna_file,
                 feature_type='bow',
                 k_mer=4,
                 return_raw=False,
                 use_tfidf=True,
                 not_return_label=False):
        '''
        Args:
            fna_file: path to an .fna file (FASTA format).
            feature_type: type of numeric feature to build ('bow' = bag of words).
            k_mer: number of nucleotides to combine into one word.
            return_raw: True to return the raw tokenized string instead of
                the numeric feature vector.
            use_tfidf: True to re-weight the bag-of-words counts with TF-IDF.
            not_return_label: True to return (data, data) pairs instead of
                (data, label).
        '''
        assert os.path.exists(fna_file), '{} does not exist'.format(fna_file)
        self.data = []
        self.label = []
        self.is_raw = return_raw
        self.k_mer = k_mer
        self.vocab = generate_k_mer_corpus(k_mer)
        self._len = 0
        with open(fna_file, 'r') as g_file:
            lines = g_file.readlines()
            lines = [line.strip() for line in lines]
            gene_str = ''
            hash_label = ''
            for line in lines:
                # Catch the start of a new sequence header; startswith also
                # safely skips blank lines.
                if line.startswith('>'):

                    # Flush the previous sequence under its hash label.
                    if hash_label != '':
                        gene_str = ensure_gene_length(k_mer, gene_str)
                        gene_str = self.tokenize_gene_str(gene_str)
                        self.data.append(gene_str)
                        self.label.append(hash_label)

                        # Track the number of genes
                        self._len += 1

                    # Reset state before reading the new sequence.
                    hash_label = ''
                    gene_str = ''
                    dot_pos = line.find('.')
                    # seq_flag indicates the 1st or 2nd sequence of a read pair.
                    seq_flag = int(line[dot_pos + 1])

                    # 1st sequence: read the hash value that encodes the label.
                    if seq_flag == 1:
                        hash_pattern = re.search(self.HASH_PATTERN, line)
                        if hash_pattern is not None:
                            hash_label = hash_pattern.group(0)

                            # Remove the brackets
                            hash_label = hash_label.replace('(', '')
                            hash_label = hash_label.replace(')', '')
                    else:
                        pass  # Ignore the 2nd sequence for now.
                # Sequence line: accumulate it into the current gene string.
                else:
                    gene_str = gene_str + line

            # The final sequence has no following header to trigger the
            # append above, so flush it here.
            if hash_label != '':
                gene_str = ensure_gene_length(k_mer, gene_str)
                gene_str = self.tokenize_gene_str(gene_str)
                self.data.append(gene_str)
                self.label.append(hash_label)
                self._len += 1

        # Build the bag-of-words count matrix; CountVectorizer learns its
        # vocabulary from the tokenized gene strings during fit.
        count_vectorizer = CountVectorizer()
        self.numeric_data = count_vectorizer.fit_transform(self.data)

        if use_tfidf:
            self.numeric_data = TfidfTransformer(
                norm='l2', sublinear_tf=True).fit_transform(self.numeric_data)
            print('Finished TF-IDF.')

        # Densify, scale by sqrt(n_features), then re-normalize to unit L2
        # rows. (Note: the final normalization cancels the sqrt scaling,
        # since scaling a row by a constant does not change its direction.)
        self.numeric_data = np.asarray(self.numeric_data.todense()) * np.sqrt(
            self.numeric_data.shape[1])
        self.numeric_data = normalize(self.numeric_data, norm='l2')
        self.numeric_data = self.numeric_data.astype('float32')

        # Map each distinct hash label to an integer id; sorting makes the
        # mapping deterministic across runs (plain set order is not).
        self.lb_mapping = self.to_onehot_mapping_2(sorted(set(self.label)))
        self.not_return_label = not_return_label

    def tokenize_gene_str(self, x: str):
        # Split the gene string into overlapping k-mers separated by spaces,
        # including the final window.
        k = self.k_mer
        tokens = [x[i:i + k] for i in range(len(x) - k + 1)]
        return ' '.join(tokens)
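
    # Worked example (hypothetical input): with k_mer=4,
    # tokenize_gene_str('ATGCAT') yields 'ATGC TGCA GCAT'.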

    def to_onehot_mapping_2(self, lb_list):
        # Despite the name, this maps each label to an integer index,
        # not a one-hot vector.
        lb_mapping = dict()
        for i, lb in enumerate(lb_list):
            lb_mapping[lb] = i

        return lb_mapping

    def __len__(self):
        # Return len of dataset in number of gene strings
        return self._len

    def __getitem__(self, idx):
        data = self.data[idx] if self.is_raw else self.numeric_data[idx]
        raw_lb = self.label[idx]
        lb = self.lb_mapping[raw_lb]

        # For label-free (e.g. autoencoder-style) training, return the
        # sample as its own target.
        if self.not_return_label:
            return (data, data)
        return (data, lb)
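
A minimal usage sketch for the dataset above, assuming PyTorch's DataLoader and a hypothetical reads.fna file on disk:

from torch.utils.data import DataLoader

# 'reads.fna' is a placeholder path; k_mer=4 matches the default above.
dataset = GenomeDataset_v2('reads.fna', k_mer=4, use_tfidf=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for features, labels in loader:
    # features: float32 TF-IDF vectors of shape (batch, vocab_size);
    # labels: integer ids from lb_mapping.
    print(features.shape, labels.shape)
    break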