Code example #1
import os

from sklearn import datasets
from sklearn.datasets import twenty_newsgroups


def iterator(dataset_fn):
    """
    Provides an iterator of parsed documents from the 20 Newsgroups dataset.

    :param dataset_fn: Path to Newsgroups dataset archive file (unused here;
        the data is fetched via scikit-learn instead).
    :type dataset_fn: unicode|str
    :rtype: generator
    """
    ng = datasets.fetch_20newsgroups()

    # Walk the articles together with their group name, numeric target and filename.
    for article, group, target, filename in zip(
            ng['data'], [ng['target_names'][x] for x in ng['target']],
            ng['target'], ng['filenames']):
        # Remove the header block, signature footer and quoted reply text.
        article = twenty_newsgroups.strip_newsgroup_header(article)
        article = twenty_newsgroups.strip_newsgroup_footer(article)
        article = twenty_newsgroups.strip_newsgroup_quoting(article)
        doc_id = os.path.basename(filename)

        yield {
            'doc_id': doc_id,
            'article': article,
            'group': group,
            'target': target,
            'filename': filename
        }
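
A minimal usage sketch; since the implementation above ignores dataset_fn, any placeholder value works:

if __name__ == '__main__':
    for doc in iterator(dataset_fn=None):
        print(doc['doc_id'], doc['group'])
        break  # only show the first parsed document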
Code example #2
    def transform(self, posts):
        # Structured array with one 'subject' and one 'body' field per post.
        features = np.recarray(shape=(len(posts),),
                               dtype=[('subject', object), ('body', object)])
        for i, text in enumerate(posts):
            # Everything before the first blank line is the header block.
            headers, _, bod = text.partition('\n\n')
            bod = strip_newsgroup_footer(bod)
            bod = strip_newsgroup_quoting(bod)
            features['body'][i] = bod

            # Extract the text of the 'Subject:' header, if present.
            prefix = 'Subject:'
            sub = ''
            for line in headers.split('\n'):
                if line.startswith(prefix):
                    sub = line[len(prefix):]
                    break
            features['subject'][i] = sub
        return features
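
For context, a transform like this typically feeds a FeatureUnion that vectorizes the two fields separately. A minimal sketch under that assumption (ItemSelector and the pipeline layout below are illustrative, not part of the original snippet):

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import FeatureUnion, Pipeline


class ItemSelector(BaseEstimator, TransformerMixin):
    """Select a single field from the structured array produced above."""
    def __init__(self, key):
        self.key = key

    def fit(self, x, y=None):
        return self

    def transform(self, data):
        return data[self.key]


# Vectorize the 'subject' and 'body' fields independently and concatenate them.
union = FeatureUnion([
    ('subject', Pipeline([('select', ItemSelector('subject')),
                          ('tfidf', TfidfVectorizer())])),
    ('body', Pipeline([('select', ItemSelector('body')),
                       ('tfidf', TfidfVectorizer())])),
])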
Code example #3
    def transform(self, posts):
        # construct object dtype array with two columns
        # first column = 'subject' and second column = 'body'
        features = np.empty(shape=(len(posts), 2), dtype=object)
        for i, text in enumerate(posts):
            # Everything before the first blank line is the header block.
            headers, _, bod = text.partition("\n\n")
            bod = strip_newsgroup_footer(bod)
            bod = strip_newsgroup_quoting(bod)
            features[i, 1] = bod

            # Extract the text of the 'Subject:' header, if present.
            prefix = "Subject:"
            sub = ""
            for line in headers.split("\n"):
                if line.startswith(prefix):
                    sub = line[len(prefix):]
                    break
            features[i, 0] = sub

        return features
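
Unlike the recarray variant in Code example #2, this version returns a plain two-column object array, so downstream steps select fields by position rather than by name. A tiny illustrative sketch (toy data, not from the original source):

import numpy as np

features = np.empty(shape=(2, 2), dtype=object)
features[0] = [' re: model question', 'first body']
features[1] = [' parts for sale', 'second body']

subjects, bodies = features[:, 0], features[:, 1]  # column 0 = subject, column 1 = body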
Code example #4
File: fetch.py  Project: chrisji/cl-quant
import pickle
import tarfile
from os.path import join, exists

from sklearn.datasets import load_files

# create_if_not_exist, download_file, _index_by_label and the strip_newsgroup_*
# helpers are defined elsewhere in the project.


def fetch_sraa(
        DOWNLOAD_URL='http://people.cs.umass.edu/~mccallum/data/sraa.tar.gz',
        dataset_home='../datasets/SRAA',
        dopickle=False):
    """
    Fetches a version of the SRAA dataset for cross-domain adaptation, as defined in:
    Dai, W., Xue, G. R., Yang, Q., & Yu, Y. (2007, August).
    Co-clustering based classification for out-of-domain documents.
    In Proceedings of the 13th ACM SIGKDD International Conference on Knowledge
    Discovery and Data Mining (pp. 210-219). ACM.
    """
    picklepath = join(dataset_home, 'sraa.pkl')
    if dopickle and exists(picklepath):
        print('...loading pickle from {}'.format(picklepath))
        return pickle.load(open(picklepath, 'rb'))

    dataset_path = join(dataset_home, 'sraa.tar.gz')

    if not exists(dataset_path):
        create_if_not_exist(dataset_home)
        print("downloading SRAA dataset (once and for all) into %s" %
              dataset_path)
        download_file(DOWNLOAD_URL, dataset_path)
        print("untarring dataset...")
        tarfile.open(dataset_path, 'r:gz').extractall(dataset_home)

    sraa = load_files(join(dataset_home, 'sraa'), encoding='latin1')
    remove = ('headers', 'footers')  # , 'quotes')
    if 'headers' in remove:
        sraa.data = [strip_newsgroup_header(text) for text in sraa.data]
    if 'footers' in remove:
        sraa.data = [strip_newsgroup_footer(text) for text in sraa.data]
    if 'quotes' in remove:
        sraa.data = [strip_newsgroup_quoting(text) for text in sraa.data]

    sraa = _index_by_label(sraa, min_words=10)

    if dopickle:
        print('...pickling the dataset into {} to speed-up next calls'.format(
            picklepath))
        pickle.dump(sraa, open(picklepath, 'wb'), pickle.HIGHEST_PROTOCOL)

    return sraa
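
A brief usage sketch (arguments are just the defaults shown above):

sraa = fetch_sraa(dopickle=True)  # downloads and untars on the first call, reuses the pickle afterwards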
Code example #5
from sklearn.datasets.twenty_newsgroups import strip_newsgroup_header, strip_newsgroup_quoting, strip_newsgroup_footer


def preprocess_input(text):
    # Strip the header block, quoted reply lines and signature footer from a raw post.
    text = strip_newsgroup_header(text)
    text = strip_newsgroup_quoting(text)
    text = strip_newsgroup_footer(text)
    return text
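
A quick usage check, assuming scikit-learn's 20 Newsgroups loader as in the other examples on this page:

from sklearn.datasets import fetch_20newsgroups

raw = fetch_20newsgroups(subset='train').data[0]
print(preprocess_input(raw))  # article text without header, quoted replies or footer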
Code example #6
import seaborn as sns
from sklearn.datasets import fetch_20newsgroups
from sklearn.datasets.twenty_newsgroups import strip_newsgroup_header, strip_newsgroup_quoting, strip_newsgroup_footer

newsgroups_train = fetch_20newsgroups(subset='train')['data']
newsgroups_test = fetch_20newsgroups(subset='test')['data']
NUM_TOPICS = 30
NUM_NEIGHBORS = 10
BUCKET = 'sagemaker-sbn'
PREFIX = '20newsgroups'

# In[ ]:

# Strip headers, quoted replies and footers from every training post.
for i in range(len(newsgroups_train)):
    newsgroups_train[i] = strip_newsgroup_header(newsgroups_train[i])
    newsgroups_train[i] = strip_newsgroup_quoting(newsgroups_train[i])
    newsgroups_train[i] = strip_newsgroup_footer(newsgroups_train[i])

# In[ ]:

# Inspect one cleaned post.
newsgroups_train[1]

# In[ ]:

get_ipython().system('pip install nltk')
import nltk
nltk.download('punkt')
nltk.download('wordnet')
from nltk import word_tokenize
from nltk.stem import WordNetLemmatizer
import re

# Tokens are runs of two or more word characters.
token_pattern = re.compile(r"(?u)\b\w\w+\b")
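
The snippet ends with the token pattern; as a rough continuation sketch (not part of the original notebook), the pattern and the WordNet lemmatizer might be combined into a tokenizer like this:

lemmatizer = WordNetLemmatizer()

def lemma_tokenize(text):
    # Lemmatize every token that matches the word pattern above.
    return [lemmatizer.lemmatize(t) for t in word_tokenize(text.lower())
            if token_pattern.match(t)]

print(lemma_tokenize(newsgroups_train[1])[:10])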