示例#1
0
文件: dataset.py 项目: HLT-ISTI/pydro
    def _load_wipo(self, class_mode, classlevel):
        assert class_mode in {'singlelabel','multilabel'}, 'available class_mode are sl (single-label) or ml (multi-label)'
        data_path = '../datasets/WIPO/wipo-gamma/en'
        data_proc = '../datasets/WIPO-extracted'

        devel = fetch_WIPOgamma(subset='train', classification_level=classlevel, data_home=data_path, extracted_path=data_proc, text_fields=['abstract'])
        test  = fetch_WIPOgamma(subset='test', classification_level=classlevel, data_home=data_path, extracted_path=data_proc, text_fields=['abstract'])

        devel_data = [d.text for d in devel]
        test_data  = [d.text for d in test]
        self.devel_raw, self.test_raw = mask_numbers(devel_data), mask_numbers(test_data)
        # self.devel_raw, self.test_raw = devel_data, test_data

        self.classification_type = class_mode
        if class_mode=='multilabel':
            devel_target = [d.all_labels for d in devel]
            test_target  = [d.all_labels for d in test]
            self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel_target, test_target)
            self.devel_target, self.test_target = self.devel_labelmatrix, self.test_labelmatrix
        else:
            devel_target = [d.main_label for d in devel]
            test_target  = [d.main_label for d in test]
            # only for labels with at least one training document
            class_id = {labelname:index for index,labelname in enumerate(sorted(set(devel_target)))}
            devel_target = np.array([class_id[id] for id in devel_target]).astype(int)
            test_target  = np.array([class_id.get(id,None) for id in test_target])
            if None in test_target:
                print(f'deleting {(test_target==None).sum()} test documents without valid categories')
                keep_pos = test_target!=None
                self.test_raw = (np.asarray(self.test_raw)[keep_pos]).tolist()
                test_target = test_target[keep_pos]
            test_target=test_target.astype(int)
            self.devel_target, self.test_target = devel_target, test_target
            self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(self.devel_target.reshape(-1, 1), self.test_target.reshape(-1, 1))
示例#2
0
文件: dataset.py 项目: HLT-ISTI/pydro
 def _load_imdb(self):
     """Load the IMDB train/test splits as a single-label dataset.

     Stores the texts (after ``mask_numbers``), the targets and the label
     matrices of both splits on the instance.
     """
     imdb_home = '../datasets/IMDB'
     train_split = fetch_IMDB(subset='train', data_home=imdb_home)
     test_split = fetch_IMDB(subset='test', data_home=imdb_home)

     self.classification_type = 'singlelabel'

     masked_devel = mask_numbers(train_split.data)
     masked_test = mask_numbers(test_split.data)
     self.devel_raw = masked_devel
     self.test_raw = masked_test

     self.devel_target = train_split.target
     self.test_target = test_split.target

     # _label_matrix is given the targets as column vectors.
     devel_column = self.devel_target.reshape(-1,1)
     test_column = self.test_target.reshape(-1,1)
     self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel_column, test_column)
示例#3
0
文件: dataset.py 项目: HLT-ISTI/pydro
 def _load_20news(self):
     """Load the 20 Newsgroups train/test splits as a single-label dataset.

     Both splits are fetched with headers, footers and quotes removed; the
     texts are then passed through ``mask_numbers``.
     """
     stripped_parts = ('headers', 'footers', 'quotes')
     train_split = fetch_20newsgroups(subset='train', remove=stripped_parts)
     test_split = fetch_20newsgroups(subset='test', remove=stripped_parts)

     self.classification_type = 'singlelabel'

     masked_devel = mask_numbers(train_split.data)
     masked_test = mask_numbers(test_split.data)
     self.devel_raw = masked_devel
     self.test_raw = masked_test

     self.devel_target = train_split.target
     self.test_target = test_split.target

     # _label_matrix is given the targets as column vectors.
     devel_column = self.devel_target.reshape(-1,1)
     test_column = self.test_target.reshape(-1,1)
     self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel_column, test_column)
示例#4
0
文件: dataset.py 项目: HLT-ISTI/pydro
    def _load_rcv1(self):
        """Load the RCV1 train/test splits as a multi-label dataset.

        For multi-label data the targets are the label matrices themselves.
        """
        corpus_dir = '../datasets/RCV1-v2/unprocessed_corpus' #TODO: check when missing
        train_split = fetch_RCV1(subset='train', data_path=corpus_dir)
        test_split = fetch_RCV1(subset='test', data_path=corpus_dir)

        self.classification_type = 'multilabel'

        masked_devel = mask_numbers(train_split.data)
        masked_test = mask_numbers(test_split.data)
        self.devel_raw = masked_devel
        self.test_raw = masked_test

        devel_matrix, test_matrix = _label_matrix(train_split.target, test_split.target)
        self.devel_labelmatrix = devel_matrix
        self.test_labelmatrix = test_matrix
        self.devel_target = self.devel_labelmatrix
        self.test_target = self.test_labelmatrix
示例#5
0
文件: dataset.py 项目: HLT-ISTI/pydro
    def _load_reuters(self):
        """Load the Reuters-21578 train/test splits as a multi-label dataset.

        The corpus is looked up in the ``reuters21578`` folder under
        ``get_data_home()``. For multi-label data the targets are the label
        matrices themselves.
        """
        corpus_dir = os.path.join(get_data_home(), 'reuters21578')
        train_split = fetch_reuters21578(subset='train', data_path=corpus_dir)
        test_split = fetch_reuters21578(subset='test', data_path=corpus_dir)

        self.classification_type = 'multilabel'

        masked_devel = mask_numbers(train_split.data)
        masked_test = mask_numbers(test_split.data)
        self.devel_raw = masked_devel
        self.test_raw = masked_test

        devel_matrix, test_matrix = _label_matrix(train_split.target, test_split.target)
        self.devel_labelmatrix = devel_matrix
        self.test_labelmatrix = test_matrix
        self.devel_target = self.devel_labelmatrix
        self.test_target = self.test_labelmatrix
示例#6
0
文件: dataset.py 项目: HLT-ISTI/pydro
 def _load_fasttext_data(self,name):
     """Load a single-label dataset stored in fastText format.

     Hyphens in ``name`` are replaced by underscores, and the files
     ``<name>.train`` and ``<name>.test`` are read from the fastText data
     folder. Only the presence of the training file is asserted up front.
     """
     data_path='../datasets/fastText'
     self.classification_type = 'singlelabel'
     name=name.replace('-','_')

     train_path = join(data_path,f'{name}.train')
     test_path = join(data_path, f'{name}.test')
     assert os.path.exists(train_path), f'file {name} not found, please place the fasttext data in {data_path}' #' or specify the path' #todo

     self.devel_raw, self.devel_target = load_fasttext_format(train_path)
     self.test_raw, self.test_target = load_fasttext_format(test_path)

     self.devel_raw = mask_numbers(self.devel_raw)
     self.test_raw = mask_numbers(self.test_raw)

     # _label_matrix is given the targets as column vectors.
     devel_column = self.devel_target.reshape(-1, 1)
     test_column = self.test_target.reshape(-1, 1)
     self.devel_labelmatrix, self.test_labelmatrix = _label_matrix(devel_column, test_column)
示例#7
0
文件: dataset.py 项目: HLT-ISTI/pydro
    def _load_jrc(self, version):
        """Load JRC-Acquis (English) as a multi-label dataset.

        Documents from 1986-2005 form the development split and documents
        from 2006 the test split; the test fetch is filtered by the training
        categories. ``version`` must be ``'300'`` (the training fetch is also
        given ``most_frequent=300``) or ``'all'``.
        """
        assert version in ['300','all'], 'allowed versions are "300" or "all"'
        corpus_dir = "../datasets/JRC_Acquis_v3"
        train_years = list(range(1986, 2006))
        test_years = [2006]

        # The two versions differ only in the extra `most_frequent` argument
        # passed to the training fetch.
        train_options = {'cat_threshold': 1}
        if version == '300':
            train_options['most_frequent'] = 300

        training_docs, tr_cats = fetch_jrcacquis(data_path=corpus_dir, years=train_years, **train_options)
        test_docs, te_cats = fetch_jrcacquis(data_path=corpus_dir, years=test_years, cat_filter=tr_cats)
        print(f'load jrc-acquis (English) with {len(tr_cats)} tr categories ({len(te_cats)} te categories)')

        devel_texts = JRCAcquis_Document.get_text(training_docs)
        test_texts = JRCAcquis_Document.get_text(test_docs)
        devel_labels = JRCAcquis_Document.get_target(training_docs)
        test_labels = JRCAcquis_Document.get_target(test_docs)

        self.classification_type = 'multilabel'

        masked_devel = mask_numbers(devel_texts)
        masked_test = mask_numbers(test_texts)
        self.devel_raw = masked_devel
        self.test_raw = masked_test

        devel_matrix, test_matrix = _label_matrix(devel_labels, test_labels)
        self.devel_labelmatrix = devel_matrix
        self.test_labelmatrix = test_matrix
        self.devel_target = self.devel_labelmatrix
        self.test_target = self.test_labelmatrix