Example #1
    def read_examples_from_file(fields, format: str, path):
        make_example = {
            'json': Example.fromJSON,
            'dict': Example.fromdict,
            'tsv': Example.fromCSV,
            'csv': Example.fromCSV
        }[format.lower()]
        lines = 0
        with open(os.path.expanduser(path), encoding="utf8") as f:
            for line in f:
                lines += 1
        with open(os.path.expanduser(path), encoding="utf8") as f:
            if format == 'csv':
                reader = unicode_csv_reader(f)
            elif format == 'tsv':
                reader = unicode_csv_reader(f, delimiter='\t')
            else:
                reader = f

            next(reader)

            examples = [
                make_example(line, fields) for line in pyprind.prog_bar(
                    reader,
                    iterations=lines,
                    title='\nReading and processing data from "' + path + '"')
            ]
        return examples
Example #2
    def __init__(self, path, format, question_fix_length, fields, bert_fields, skip_header=False,
                 csv_reader_params={}, **kwargs):
        format = format.lower()
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format]
        make_bert_example = Example.fromlist

        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            if format == 'csv':
                reader = unicode_csv_reader(f, **csv_reader_params)
            elif format == 'tsv':
                reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
            else:
                reader = f

            if format in ['csv', 'tsv'] and isinstance(fields, dict):
                if skip_header:
                    raise ValueError('When using a dict to specify fields with a {} file, '
                                     'skip_header must be False and '
                                     'the file must have a header.'.format(format))
                header = next(reader)
                field_to_index = {f: header.index(f) for f in fields.keys()}
                make_example = partial(make_example, field_to_index=field_to_index)

            if skip_header:
                next(reader)
            
            examples = [make_example(line, fields) for line in reader]
            bert_data = []
            for d in examples:
                question = getattr(d, 'question')
                question = question[-(min(len(question), question_fix_length)):]
                cat_d = []
                for name in bert_fields[:-1]:
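                    # 101 and 102 are the [CLS] and [SEP] token ids in the standard BERT vocabulary.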
                    cat_d.append([101] + question + [102] + getattr(d, name[0]) )
                    
                bert_data.append(cat_d)
       
            _ = [a.append(l.label) for a, l in zip(bert_data, examples)]
            bert_examples = [make_bert_example(data, bert_fields) for data in bert_data]
            
        # The following flattening mirrors TabularDataset and does not handle the BERT fields.
        if isinstance(fields, dict):
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        super(BertTabularDataset_MultipleChoice, self).__init__(bert_examples, bert_fields, **kwargs)
Example #3
File: iterable_train.py Project: oja/text
def count(data_path):
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        labels = [int(row[0]) for row in reader]
        num_lines = len(labels)
        num_labels = len(set(labels))
        return num_labels, num_lines
Example #4
 def __init__(self,
              path: str,
              fields: Sequence[Tuple[str, Field]],
              num_samples: Optional[int] = None,
              add_cls: bool = False,
              random_state: int = 162,
              max_len: Optional[int] = None,
              verbose: bool = True,
              **kwargs):
     duplicate_spaces_re = re.compile(r' +')
     with open(path, 'r', encoding='utf-8') as fp:
         all_data = []
         reader = unicode_csv_reader(fp)
         for row in reader:
             cls, text = row[0], row[1]
             if max_len is not None and len(text.split()) > max_len:
                 continue
             text = text.replace('\\n\\n', '\\n ')
             text = duplicate_spaces_re.sub(' ', text)
             data = (text, text, cls) if add_cls else (text, text)
             all_data.append(data)
     if num_samples is not None and num_samples < len(all_data):
         random.seed(random_state)
         all_data = random.sample(all_data, num_samples)
     examples = []
     for data in tqdm(all_data,
                      desc='Converting data into examples',
                      disable=not verbose):
         examples.append(Example.fromlist(data=data, fields=fields))
     super().__init__(examples=examples, fields=fields, **kwargs)
Example #5
def fields(id_attr, label_attr, data_dir, train_filename):
    ignore_columns = ["left_id", "right_id"]
    with io.open(os.path.expanduser(os.path.join(data_dir, train_filename)),
                 encoding="utf8") as f:
        header = next(unicode_csv_reader(f))
    return _make_fields(header, id_attr, label_attr, ignore_columns, True,
                        "nltk", False)
Example #6
def test_make_fields_1():
    path = os.path.join(test_dir_path, "test_datasets")
    a_dataset = "sample_table_large.csv"
    with io.open(os.path.expanduser(os.path.join(path, a_dataset)),
                 encoding="utf8") as f:
        header = next(unicode_csv_reader(f))
    assert header == [
        "_id",
        "ltable_id",
        "rtable_id",
        "label",
        "ltable_Song_Name",
        "ltable_Artist_Name",
        "ltable_Price",
        "ltable_Released",
        "rtable_Song_Name",
        "rtable_Artist_Name",
        "rtable_Price",
        "rtable_Released",
    ]
    id_attr = "_id"
    label_attr = "label"
    fields = _make_fields(header, id_attr, label_attr,
                          ["ltable_id", "rtable_id"], True, "nltk", True)
    assert len(fields) == 12
    counter = {}
    for tup in fields:
        if tup[1] not in counter:
            counter[tup[1]] = 0
        counter[tup[1]] += 1
    assert sorted(list(counter.values())) == [1, 1, 2, 8]
Example #7
def csv_iterator(data_path, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            tokens = ' '.join(row[1:])
            yield ngrams_iterator(tokenizer(tokens), ngrams)
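Since csv_iterator yields one token iterator per row, its output can be fed directly to torchtext's vocab builder. A minimal usage sketch, assuming an AG_NEWS-style CSV at a hypothetical path ("train.csv"):

from torchtext.vocab import build_vocab_from_iterator

# "train.csv" is a placeholder path: first column = label, remaining columns = text.
vocab = build_vocab_from_iterator(csv_iterator("train.csv", ngrams=2))
print(len(vocab))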
Example #8
File: dataset.py Project: lzzppp/DERT
    def __init__(self,
                 fields,
                 column_naming,
                 path=None,
                 format='csv',
                 examples=None,
                 metadata=None,
                 **kwargs):
        if examples is None:
            make_example = {
                'json': Example.fromJSON,
                'dict': Example.fromdict,
                'tsv': Example.fromCSV,
                'csv': Example.fromCSV
            }[format.lower()]

            lines = 0
            with open(os.path.expanduser(path), encoding="utf8") as f:
                for line in f:
                    lines += 1

            with open(os.path.expanduser(path), encoding="utf8") as f:
                if format == 'csv':
                    reader = unicode_csv_reader(f)
                elif format == 'tsv':
                    reader = unicode_csv_reader(f, delimiter='\t')
                else:
                    reader = f

                next(reader)
                examples = [
                    make_example(line, fields) for line in pyprind.prog_bar(
                        reader,
                        iterations=lines,
                        title='\nReading and processing data from "' + path +
                        '"')
                ]

            super(MatchingDataset, self).__init__(examples, fields, **kwargs)
        else:
            self.fields = dict(fields)
            self.examples = examples
            self.metadata = metadata

        self.path = path
        self.column_naming = column_naming
        self._set_attributes()
Example #9
    def __init__(self, path, format, fields, skip_header=False, **kwargs):
        """Create a DuplicateTabularDataset given a path, file format, and field list.
        A single column (or key) in the data file can be mapped to multiple field objects, breaking the usual one-to-one constraint.
        Arguments:
            path (str): Path to the data file.
            format (str): The format of the data file. One of "CSV", "TSV", or
                "JSON" (case-insensitive).
            fields (dict[str: tuple(tuple(str, Field), ...)]):
                If using a dict, the keys should be a subset of the JSON keys or CSV/TSV
                columns, and the values should be tuples of (name, field).
                Keys not present in the input dictionary are ignored.
                This allows the user to rename columns from their JSON/CSV/TSV key names
                and also enables selecting a subset of columns to load.
            skip_header (bool): Whether to skip the first line of the input file.
        """
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format.lower()]

        print(path)

        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            if format == 'csv':
                reader = unicode_csv_reader(f)
            elif format == 'tsv':
                reader = unicode_csv_reader(f, delimiter='\t')
            else:
                reader = f

            if format in ['csv', 'tsv']:
                if skip_header:
                    raise ValueError('When using a dict to specify fields with a {} file, '
                                     'skip_header must be False and '
                                     'the file must have a header.'.format(format))
                header = next(reader)
                print(header)
                print(fields.keys())
                field_to_index = {f: header.index(f) for f in fields.keys()}
                make_example = partial(make_example, field_to_index=field_to_index)

            if skip_header:
                next(reader)

            examples = [make_example(line, fields) for line in reader]

            super(DuplicateTabularDataset, self).__init__(examples, fields, **kwargs)
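The fields dict accepted above lets a single source column feed several example attributes. A hedged sketch of the expected shape (column names, attribute names, and Field options are illustrative, not taken from the project); a list is used for the multi-mapped column, which is what the stock torchtext Example.fromdict accepts:

from torchtext.data import Field

TEXT = Field(sequential=True, lower=True)
RAW = Field(sequential=False, use_vocab=False)

# One source column ("question") mapped to two example attributes.
fields = {
    'question': [('question_tok', TEXT), ('question_raw', RAW)],
    'label': ('label', Field(sequential=False)),
}
# dataset = DuplicateTabularDataset('data.csv', 'csv', fields)
# ("data.csv" must contain 'question' and 'label' header columns.)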
Example #10
File: iterable_train.py Project: oja/text
 def iterator(start, num_lines):
     tokenizer = get_tokenizer("basic_english")
     with io.open(data_path, encoding="utf8") as f:
         reader = unicode_csv_reader(f)
         for i, row in enumerate(reader):
             if i == start:
                 break
         for _ in range(num_lines):
             tokens = ' '.join(row[1:])
             tokens = ngrams_iterator(tokenizer(tokens), ngrams)
             yield int(row[0]) - 1, torch.tensor(
                 [vocab[token] for token in tokens])
             try:
                 row = next(reader)
             except StopIteration:
                 f.seek(0)
                 reader = unicode_csv_reader(f)
                 row = next(reader)
Example #11
def count(data_path):
    r"""
    Return the number of distinct labels and the total number of text entries.
    """
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        labels = [int(row[0]) for row in reader]
        num_lines = len(labels)
        num_labels = len(set(labels))
        return num_labels, num_lines
Example #12
def _csv_iterator(data_path, tokenizer, ngrams, yield_cls=False):
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            tokens = ' '.join(row[1:])
            tokens = tokenizer(tokens)
            if yield_cls:
                yield int(row[0]) - 1, ngrams_iterator(tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
Example #13
def process_unlabeled(path, trained_model, ignore_columns=None):
    """Creates a dataset object for an unlabeled dataset.

    Args:
        path (string):
            The full path to the unlabeled data file (not just the directory).
        trained_model (:class:`~deepmatcher.MatchingModel`):
            The trained model. The model is aware of the configuration of the training
            data on which it was trained, and so this method reuses the same
            configuration for the unlabeled data.
        ignore_columns (list):
            A list of columns to ignore in the unlabeled CSV file.
    """
    with io.open(path, encoding="utf8") as f:
        header = next(unicode_csv_reader(f))

    train_info = trained_model.meta
    if ignore_columns is None:
        ignore_columns = train_info.ignore_columns
    column_naming = dict(train_info.column_naming)
    column_naming['label'] = None

    fields = _make_fields(header, column_naming['id'], column_naming['label'],
                          ignore_columns, train_info.lowercase,
                          train_info.tokenize, train_info.include_lengths)

    begin = timer()
    dataset_args = {'fields': fields, 'column_naming': column_naming}
    dataset = MatchingDataset(path=path, **dataset_args)

    # Make sure we have the same attributes.
    assert set(dataset.all_text_fields) == set(train_info.all_text_fields)

    after_load = timer()
    logger.info('Data load time: {}s'.format(after_load - begin))

    reverse_fields_dict = dict((pair[1], pair[0]) for pair in fields)
    for field, name in reverse_fields_dict.items():
        if field is not None and field.use_vocab:
            # Copy over vocab from original train data.
            field.vocab = copy.deepcopy(train_info.vocabs[name])
            # Then extend the vocab.
            field.extend_vocab(dataset,
                               vectors=train_info.embeddings,
                               cache=train_info.embeddings_cache)

    dataset.vocabs = {
        name: dataset.fields[name].vocab
        for name in train_info.all_text_fields
    }

    after_vocab = timer()
    logger.info('Vocab update time: {}s'.format(after_vocab - after_load))

    return dataset
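A minimal usage sketch for process_unlabeled, assuming a previously trained deepmatcher model whose state was saved to disk (the file paths are placeholders):

import deepmatcher as dm

model = dm.MatchingModel()
model.load_state('hybrid_model.pth')           # hypothetical checkpoint produced by run_train
unlabeled = process_unlabeled('unlabeled.csv', trained_model=model)
predictions = model.run_prediction(unlabeled)  # one match score per tuple pair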
Example #14
def test_check_header_1():
    path = os.path.join(test_dir_path, "test_datasets")
    a_dataset = "sample_table_small.csv"
    with io.open(os.path.expanduser(os.path.join(path, a_dataset)),
                 encoding="utf8") as f:
        header = next(unicode_csv_reader(f))
    assert header == ["id", "left_a", "right_a", "label"]
    id_attr = "id"
    label_attr = "label"
    left_prefix = "left"
    right_prefix = "right"
    _check_header(header, id_attr, left_prefix, right_prefix, label_attr, [])
Example #15
def _csv_iterator(data_path, ngrams, yield_cls=False):
    tokenizer = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            tokens = row[1]
            tokens = tokenizer(tokens)
            if yield_cls:
                yield 1 if int(row[4]) == 3 else 0, ngrams_iterator(
                    tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
Example #16
def _csv_iterator(data_path, ngrams, dataset_name=None, yield_cls=False):
    # tokenizer = get_tokenizer("basic_english")
    tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f, delimiter='\t')
        for row in reader:
            tokens = row[1]
            tokens = tokenizer(tokens)
            if yield_cls:
                label = int(LABELS[dataset_name][row[0]]) - 1
                yield label, ngrams_iterator(tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
Example #17
 def test_check_header_5(self):
     path = os.path.join(test_dir_path, 'test_datasets')
     a_dataset = 'sample_table_small.csv'
     with io.open(os.path.expanduser(os.path.join(path, a_dataset)),
                  encoding="utf8") as f:
         header = next(unicode_csv_reader(f))
     self.assertEqual(header, ['id', 'left_a', 'right_a', 'label'])
     id_attr = 'id'
     label_attr = ''
     left_prefix = 'left'
     right_prefix = 'right'
     _check_header(header, id_attr, left_prefix, right_prefix, label_attr,
                   [])
Example #18
def _create_data_with_sp_transform(sp_generator, data_path):

    data = []
    labels = []
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            corpus = ' '.join(row[1:])
            token_ids = list(sp_generator([corpus]))[0]
            label = int(row[0]) - 1
            data.append((label, torch.tensor(token_ids)))
            labels.append(label)
    return data, set(labels)
Example #19
def _csv_iterator(data_path, ngrams, yield_cls=False, label=-1):
    tokenizer = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f, delimiter="\t")
        for row in reader:
            tokens = ' '.join([row[5]])
            #print(row[5])
            tokens = tokenizer(tokens)

            if yield_cls:
                yield row[7], ngrams_iterator(tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
Example #20
    def iterator(start, num_lines):
        tokenizer = get_tokenizer("basic_english")
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for i, row in enumerate(reader):
                if i == start:
                    break
            for _ in range(num_lines):
                tokens = ' '.join(row[1:])
                tokens = ngrams_iterator(tokenizer(tokens), ngrams)

                label_onehot = [0.0 for _ in range(20)]
                for classNum in row[0].split(' '):
                    label_onehot[int(classNum)] = 1.0

                yield label_onehot, torch.tensor(
                    [vocab[token] for token in tokens])
                try:
                    row = next(reader)
                except StopIteration:
                    f.seek(0)
                    reader = unicode_csv_reader(f)
                    row = next(reader)
Example #21
def _create_data_with_sp_transform(data_path):

    data = []
    labels = []
    spm_path = pretrained_sp_model['text_unigram_15000']
    text_pipeline = sentencepiece_processor(download_from_url(spm_path))
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            corpus = ' '.join(row[1:])
            token_ids = text_pipeline(corpus)
            label = int(row[0]) - 1
            data.append((label, torch.tensor(token_ids)))
            labels.append(label)
    return data, set(labels)
Example #22
def csv_iterator(data_path, ngrams, yield_cls=False):
    """
    Load a CSV text file and generate n-gram token samples of the specified order from the raw text.
    Args:
        data_path: path to the CSV file.
        ngrams: n-gram order passed to ngrams_iterator.
        yield_cls: whether to also yield the zero-based class label.

    Returns:

    """
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            tokens = ' '.join(row[1:])
            tokens = tokenizer(tokens)
            if yield_cls:
                yield int(row[0]) - 1, ngrams_iterator(tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
Example #23
    def test_text_nomalize_function(self):
        # Test text_nomalize function in torchtext.datasets.text_classification
        ref_lines = []
        test_lines = []

        tokenizer = data.get_tokenizer("basic_english")
        data_path = 'test/asset/text_normalization_ag_news_test.csv'
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                test_lines.append(tokenizer(' , '.join(row)))

        data_path = 'test/asset/text_normalization_ag_news_ref_results.test'
        with io.open(data_path, encoding="utf8") as ref_data:
            for line in ref_data:
                line = line.split()
                self.assertEqual(line[0][:9], '__label__')
                line[0] = line[0][9:]  # remove '__label__'
                ref_lines.append(line)

        self.assertEqual(ref_lines, test_lines)
Example #24
 def test_make_fields_1(self):
     path = os.path.join(test_dir_path, 'test_datasets')
     a_dataset = 'sample_table_large.csv'
     with io.open(os.path.expanduser(os.path.join(path, a_dataset)),
                  encoding="utf8") as f:
         header = next(unicode_csv_reader(f))
     self.assertEqual(header, [
         '_id', 'ltable_id', 'rtable_id', 'label', 'ltable_Song_Name',
         'ltable_Artist_Name', 'ltable_Price', 'ltable_Released',
         'rtable_Song_Name', 'rtable_Artist_Name', 'rtable_Price',
         'rtable_Released'
     ])
     id_attr = '_id'
     label_attr = 'label'
     fields = _make_fields(header, id_attr, label_attr,
                           ['ltable_id', 'rtable_id'], True, 'nltk', True)
     self.assertEqual(len(fields), 12)
     counter = {}
     for tup in fields:
         if tup[1] not in counter:
             counter[tup[1]] = 0
         counter[tup[1]] += 1
     self.assertEqual(sorted(list(counter.values())), [1, 1, 2, 8])
Example #25
    def setUp(self):
        self.data_dir = os.path.join(test_dir_path, 'test_datasets')
        self.train = 'test_train.csv'
        self.validation = 'test_valid.csv'
        self.test = 'test_test.csv'
        self.cache_name = 'test_cacheddata.pth'
        with io.open(os.path.expanduser(os.path.join(self.data_dir,
                                                     self.train)),
                     encoding="utf8") as f:
            header = next(unicode_csv_reader(f))

        id_attr = 'id'
        label_attr = 'label'
        ignore_columns = ['left_id', 'right_id']
        self.fields = _make_fields(header, id_attr, label_attr, ignore_columns,
                                   True, 'nltk', False)

        self.column_naming = {
            'id': id_attr,
            'left': 'left_',
            'right': 'right_',
            'label': label_attr
        }
Example #26
def _csv_iterator(data_path,
                  ngrams,
                  skip_header=True,
                  yield_cls=False,
                  label_col=6,
                  token_col=[1, 5],
                  label_mapping={
                      "simulation": 0,
                      "hardware": 1,
                      "edge_computing": 2
                  }):
    tokenizer = get_tokenizer("spacy", "en_core_web_sm")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        if skip_header:
            next(reader, None)
        for row in reader:
            tokens = ' '.join([j for i, j in enumerate(row) if i in token_col])
            tokens = tokenizer(tokens)
            if yield_cls:
                yield label_mapping[row[label_col]], ngrams_iterator(
                    tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
Example #27
def _create_data_from_csv(data_path):
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            yield int(row[0]), ' '.join(row[1:])
Example #28
def _read_text_iterator(path):
    with io.open(path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            yield " ".join(row)
Example #29
def process(
    path,
    train=None,
    validation=None,
    test=None,
    cache="cacheddata.pth",
    check_cached_data=True,
    auto_rebuild_cache=True,
    tokenize="nltk",
    lowercase=True,
    embeddings="fasttext.en.bin",
    embeddings_cache_path="~/.vector_cache",
    ignore_columns=(),
    include_lengths=True,
    id_attr="id",
    label_attr="label",
    left_prefix="left_",
    right_prefix="right_",
    use_magellan_convention=False,
    pca=True,
):
    r"""Creates dataset objects for multiple splits of a dataset.

    This involves the following steps (if data cannot be retrieved from the cache):
    #. Read CSV header of a data file and verify header is sane.
    #. Create fields, i.e., column processing specifications (e.g. tokenization, label
        conversion to integers etc.)
    #. Load each data file:
        #. Read each example (tuple pair) in specified CSV file.
        #. Preprocess example. Involves lowercasing and tokenization (unless disabled).
        #. Compute metadata if training data file. \
            See :meth:`MatchingDataset.compute_metadata` for details.
    #. Create vocabulary consisting of all tokens in all attributes in all datasets.
    #. Download word embedding data if necessary.
    #. Create mapping from each word in vocabulary to its word embedding.
    #. Compute metadata
    #. Write to cache

    Arguments:
        path (str): Common prefix of the splits' file paths.
        train (str): Suffix to add to path for the train set.
        validation (str): Suffix to add to path for the validation set, or None
            for no validation set. Default is None.
        test (str): Suffix to add to path for the test set, or None for no test
            set. Default is None.
        cache (str): Suffix to add to path for the cache file. If `None`, caching is disabled.
        check_cached_data (bool): Verify that data files haven't changed since the
            cache was constructed and that relevant field options haven't changed.
        auto_rebuild_cache (bool): Automatically rebuild the cache if the data files
            are modified or if the field options change. Defaults to True.
        tokenize (str): Which tokenizer to use.
        lowercase (bool): Whether to lowercase all words in all attributes.
        embeddings (str or list): One or more of the following strings:
            * `fasttext.{lang}.bin`:
                This uses sub-word level word embeddings based on binary models from "wiki
                word vectors" released by FastText. {lang} is 'en' or any other 2 letter
                ISO 639-1 Language Code, or 3 letter ISO 639-2 Code, if the language does
                not have a 2 letter code. 300d vectors.
                ``fasttext.en.bin`` is the default.
            * `fasttext.wiki.vec`:
                Uses wiki news word vectors released as part of "Advances in Pre-Training
                Distributed Word Representations" by Mikolov et al. (2018). 300d vectors.
            * `fasttext.crawl.vec`:
                Uses Common Crawl word vectors released as part of "Advances in
                Pre-Training Distributed Word Representations" by Mikolov et al. (2018).
                300d vectors.
            * `glove.6B.{dims}`:
                Uses uncased Glove trained on Wiki + Gigaword. {dims} is one of (50d,
                100d, 200d, or 300d).
            * `glove.42B.300d`:
                Uses uncased Glove trained on Common Crawl. 300d vectors.
            * `glove.840B.300d`:
                Uses cased Glove trained on Common Crawl. 300d vectors.
            * `glove.twitter.27B.{dims}`:
                Uses cased Glove trained on Twitter. {dims} is one of (25d, 50d, 100d, or
                200d).
        embeddings_cache_path (str): Directory to store downloaded word vector data.
        ignore_columns (list): A list of columns to ignore in the CSV files.
        include_lengths (bool): Whether to provide the model with the lengths of
            each attribute sequence in each batch. If True, length information can be
            used by the neural network, e.g. when picking the last RNN output of each
            attribute sequence.
        id_attr (str): The name of the tuple pair ID column in the CSV file.
        label_attr (str): The name of the tuple pair match label column in the CSV file.
        left_prefix (str): The prefix for attribute names belonging to the left table.
        right_prefix (str): The prefix for attribute names belonging to the right table.
        use_magellan_convention (bool): Set `id_attr`, `left_prefix`, and `right_prefix`
            according to Magellan (py_entitymatching Python package) naming conventions.
            Specifically, set them to be '_id', 'ltable_', and 'rtable_' respectively.
        pca (bool): Whether to compute PCA for each attribute (needed for SIF model).
            Defaults to True.

    Returns:
        Tuple[MatchingDataset]: Datasets for the (train, validation, test) splits, in that
            order, for whichever splits were provided.

    """
    if use_magellan_convention:
        id_attr = "_id"
        left_prefix = "ltable_"
        right_prefix = "rtable_"

    # TODO(Sid): check for all datasets to make sure the files exist and have the same schema
    a_dataset = train or validation or test
    with io.open(os.path.expanduser(os.path.join(path, a_dataset)),
                 encoding="utf8") as f:
        header = next(unicode_csv_reader(f))

    _maybe_download_nltk_data()
    _check_header(header, id_attr, left_prefix, right_prefix, label_attr,
                  ignore_columns)
    fields = _make_fields(
        header,
        id_attr,
        label_attr,
        ignore_columns,
        lowercase,
        tokenize,
        include_lengths,
    )

    column_naming = {
        "id": id_attr,
        "left": left_prefix,
        "right": right_prefix,
        "label": label_attr,
    }

    datasets = MatchingDataset.splits(
        path,
        train,
        validation,
        test,
        fields,
        embeddings,
        embeddings_cache_path,
        column_naming,
        cache,
        check_cached_data,
        auto_rebuild_cache,
        train_pca=pca,
    )

    # Save additional information to train dataset.
    datasets[0].ignore_columns = ignore_columns
    datasets[0].tokenize = tokenize
    datasets[0].lowercase = lowercase
    datasets[0].include_lengths = include_lengths

    return datasets
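A usage sketch for process (the directory and file names are placeholders):

train, validation, test = process(
    path='sample_data',                      # directory containing the CSV splits
    train='train.csv',
    validation='validation.csv',
    test='test.csv',
    ignore_columns=('left_id', 'right_id'),  # illustrative; depends on the CSV schema
)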
Example #30
    def __init__(self, path, format, fields, skip_header=False,
                 csv_reader_params=None, **kwargs):
        """Create a TabularDataset given a path, file format, and field list.

        Args:
            path (str): Path to the data file.
            format (str): The format of the data file. One of "CSV", "TSV", or
                "JSON" (case-insensitive).
            fields (list(tuple(str, Field)) or dict[str: tuple(str, Field)]): If using a list,
                the format must be CSV or TSV, and the values of the list
                should be tuples of (name, field).
                The fields should be in the same order as the columns in the CSV or TSV
                file, while tuples of (name, None) represent columns that will be ignored.

                If using a dict, the keys should be a subset of the JSON keys or CSV/TSV
                columns, and the values should be tuples of (name, field).
                Keys not present in the input dictionary are ignored.
                This allows the user to rename columns from their JSON/CSV/TSV key names
                and also enables selecting a subset of columns to load.
            skip_header (bool): Whether to skip the first line of the input file.
            csv_reader_params(dict): Parameters to pass to the csv reader.
                Only relevant when format is csv or tsv.
                See
                https://docs.python.org/3/library/csv.html#csv.reader
                for more details.
            kwargs (dict): passed to the Dataset parent class.
        """
        if csv_reader_params is None:
            csv_reader_params = {}
        format = format.lower()
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromCSV, 'csv': Example.fromCSV}[format]

        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            if format == 'csv':
                reader = unicode_csv_reader(f, **csv_reader_params)
            elif format == 'tsv':
                reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
            else:
                reader = f

            if format in ['csv', 'tsv'] and isinstance(fields, dict):
                if skip_header:
                    raise ValueError('When using a dict to specify fields with a {} file, '
                                     'skip_header must be False and '
                                     'the file must have a header.'.format(format))
                header = next(reader)
                field_to_index = {f: header.index(f) for f in fields.keys()}
                make_example = partial(make_example, field_to_index=field_to_index)

            if skip_header:
                next(reader)

            examples = [make_example(line, fields) for line in reader]

        if isinstance(fields, dict):
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        super(TabularDataset, self).__init__(examples, fields, **kwargs)
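To illustrate the fields argument described in the docstring, a hedged construction sketch for a two-column TSV (the file name, column order, and Field options are assumptions):

from torchtext.data import Field

TEXT = Field(sequential=True, lower=True)
LABEL = Field(sequential=False)  # the label column gets its own vocab

# List form: order must match the TSV columns; a (name, None) entry would skip a column.
dataset = TabularDataset(
    path='reviews.tsv', format='tsv',
    fields=[('text', TEXT), ('label', LABEL)],
    skip_header=True)

# Dict form: keys are header names, so column order no longer matters
# (requires a header row and skip_header=False).
# dataset = TabularDataset(path='reviews.tsv', format='tsv',
#                          fields={'text': ('text', TEXT), 'label': ('label', LABEL)})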