def read_examples_from_file(fields, format: str, path):
    """Read a data file and convert each record into an example.

    The first record yielded by the reader is always discarded (presumably
    the header row) before examples are built.

    Args:
        fields: Field specification forwarded unchanged to the ``Example``
            factory selected by ``format``.
        format: One of ``'json'``, ``'dict'``, ``'tsv'`` or ``'csv'``
            (case-insensitive).
        path: Path to the data file; a leading ``~`` is expanded.

    Returns:
        list: One example per record that follows the first one.

    Raises:
        KeyError: If ``format`` is not one of the supported names.
        StopIteration: If the file contains no records at all.
    """
    # Normalise once so that factory lookup and reader selection agree;
    # previously "CSV"/"TSV" picked the CSV factory but the raw-line reader.
    format = format.lower()
    make_example = {
        'json': Example.fromJSON,
        'dict': Example.fromdict,
        'tsv': Example.fromCSV,
        'csv': Example.fromCSV,
    }[format]
    expanded_path = os.path.expanduser(path)

    # First pass: count physical lines so the progress bar has a total.
    with open(expanded_path, encoding="utf8") as f:
        lines = sum(1 for _ in f)

    # Second pass: parse the records.
    with open(expanded_path, encoding="utf8") as f:
        if format == 'csv':
            reader = unicode_csv_reader(f)
        elif format == 'tsv':
            reader = unicode_csv_reader(f, delimiter='\t')
        else:
            reader = f

        next(reader)  # drop the first record
        examples = [
            make_example(line, fields)
            for line in pyprind.prog_bar(
                reader,
                iterations=lines,
                title='\nReading and processing data from "' + path + '"')
        ]
    return examples
def __init__(self, path, format, question_fix_length, fields, bert_fields,
             skip_header=False, csv_reader_params=None, **kwargs):
    """Load a tabular file and derive multiple-choice examples from it.

    Every record is first parsed with ``fields``. Then, for each entry of
    ``bert_fields`` except the last, the sequence
    ``[101] + question + [102] + <attribute>`` is assembled, where
    ``question`` is the example's ``question`` attribute cut down to its
    trailing ``question_fix_length`` items and ``<attribute>`` is looked up
    by the entry's first element. The example's ``label`` attribute is
    appended as the final column and the row is converted with
    ``Example.fromlist(row, bert_fields)``.

    Args:
        path (str): Path to the data file; a leading ``~`` is expanded.
        format (str): "CSV", "TSV", "JSON" or "dict" (case-insensitive).
        question_fix_length (int): Maximum number of trailing question items
            to keep. NOTE: a value of 0 keeps the whole question, because
            the slice becomes ``question[-0:]``.
        fields: List of ``(name, field)`` tuples or a dict mapping column
            names to them; used to parse the raw records.
        bert_fields: Sequence of ``(name, field)`` entries; all but the last
            name an example attribute to pair with the question, the last
            receives the label.
        skip_header (bool): Whether to skip the first record of the file.
        csv_reader_params (dict): Extra keyword arguments for the CSV
            reader. ``None`` (the default) means no extra arguments.
        kwargs (dict): Passed to the parent class constructor.

    Raises:
        ValueError: If ``fields`` is a dict, the format is CSV/TSV and
            ``skip_header`` is true, or if a dict key is not in the header.
    """
    # A fresh dict per call instead of a shared mutable default.
    if csv_reader_params is None:
        csv_reader_params = {}
    format = format.lower()
    make_example = {
        'json': Example.fromJSON,
        'dict': Example.fromdict,
        'tsv': Example.fromCSV,
        'csv': Example.fromCSV,
    }[format]
    make_bert_example = Example.fromlist

    with io.open(os.path.expanduser(path), encoding="utf8") as f:
        if format == 'csv':
            reader = unicode_csv_reader(f, **csv_reader_params)
        elif format == 'tsv':
            reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
        else:
            reader = f

        if format in ['csv', 'tsv'] and isinstance(fields, dict):
            if skip_header:
                raise ValueError('When using a dict to specify fields with a {} file, '
                                 'skip_header must be False and '
                                 'the file must have a header.'.format(format))
            # The header is consumed here to map column names to positions.
            header = next(reader)
            field_to_index = {column: header.index(column) for column in fields}
            make_example = partial(make_example, field_to_index=field_to_index)

        if skip_header:
            next(reader)

        examples = [make_example(line, fields) for line in reader]

    bert_data = []
    for example in examples:
        question = example.question
        # Keep at most the trailing `question_fix_length` items.
        question = question[-(min(len(question), question_fix_length)):]
        # 101 / 102 are presumably the BERT [CLS] / [SEP] token ids --
        # TODO confirm against the tokenizer vocabulary in use.
        row = [[101] + question + [102] + getattr(example, entry[0])
               for entry in bert_fields[:-1]]
        # The label is the final column, matching the last bert_fields entry.
        row.append(example.label)
        bert_data.append(row)

    bert_examples = [make_bert_example(row, bert_fields) for row in bert_data]

    # not deal with bert situation
    # (the flattened `fields` list below is not passed to the parent class)
    if isinstance(fields, dict):
        fields, field_dict = [], fields
        for field in field_dict.values():
            if isinstance(field, list):
                fields.extend(field)
            else:
                fields.append(field)

    super(BertTabularDataset_MultipleChoice, self).__init__(
        bert_examples, bert_fields, **kwargs)
def count(data_path):
    """Count distinct labels and total rows in a CSV file.

    The first column of every row is parsed as an integer label.

    Args:
        data_path: Path of the CSV file to scan.

    Returns:
        tuple: ``(num_labels, num_lines)`` -- the number of distinct labels
        followed by the number of rows.
    """
    distinct_labels = set()
    total_rows = 0
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            distinct_labels.add(int(record[0]))
            total_rows += 1
    return len(distinct_labels), total_rows
def __init__(self, path: str, fields: Sequence[Tuple[str, Field]],
             num_samples: Optional[int] = None, add_cls: bool = False,
             random_state: int = 162, max_len: Optional[int] = None,
             verbose: bool = True, **kwargs):
    """Build a dataset from a CSV file of ``class, text`` rows.

    Each kept row contributes the tuple ``(text, text)`` -- or
    ``(text, text, cls)`` when ``add_cls`` is true -- which is converted
    with ``Example.fromlist`` using ``fields``.

    Args:
        path: Path of the CSV file; column 0 is the class, column 1 the text.
        fields: ``(name, field)`` pairs handed to ``Example.fromlist`` and to
            the parent constructor.
        num_samples: If given and smaller than the number of kept rows, a
            random subset of this size is used instead of all rows.
        add_cls: Whether to append the class column to each data tuple.
        random_state: Seed for the subset sampling.
        max_len: If given, rows whose text has more whitespace-separated
            tokens than this are skipped.
        verbose: Whether to show the conversion progress bar.
        **kwargs: Passed to the parent class constructor.
    """
    duplicate_spaces_re = re.compile(r' +')

    all_data = []
    with open(path, 'r', encoding='utf-8') as fp:
        for row in unicode_csv_reader(fp):
            cls, text = row[0], row[1]
            # The length filter looks at the raw text, before clean-up.
            if max_len is not None and len(text.split()) > max_len:
                continue
            # Collapse a doubled literal "\n" escape (backslash + n, not a
            # real newline) into a single one followed by a space.
            text = text.replace('\\n\\n', '\\n ')
            text = duplicate_spaces_re.sub(' ', text)
            all_data.append((text, text, cls) if add_cls else (text, text))

    if num_samples is not None and num_samples < len(all_data):
        # A private generator draws the same sample that
        # `random.seed(random_state); random.sample(...)` would, without
        # reseeding the process-wide RNG as a constructor side effect.
        rng = random.Random(random_state)
        all_data = rng.sample(all_data, num_samples)

    examples = [
        Example.fromlist(data=data, fields=fields)
        for data in tqdm(all_data, desc='Converting data into examples',
                         disable=not verbose)
    ]
    super().__init__(examples=examples, fields=fields, **kwargs)
def fields(id_attr, label_attr, data_dir, train_filename):
    """Create the field list for a training CSV file.

    Reads only the first row (the header) of ``data_dir/train_filename`` and
    passes it to ``_make_fields`` together with the id and label column
    names. The columns "left_id" and "right_id" are ignored.

    Returns:
        Whatever ``_make_fields`` returns for that header.
    """
    train_path = os.path.expanduser(os.path.join(data_dir, train_filename))
    with io.open(train_path, encoding="utf8") as handle:
        column_names = next(unicode_csv_reader(handle))
    skipped_columns = ["left_id", "right_id"]
    return _make_fields(column_names, id_attr, label_attr, skipped_columns,
                        True, "nltk", False)
def test_make_fields_1():
    """Check the header of the large sample table and the fields built from it."""
    csv_path = os.path.expanduser(
        os.path.join(test_dir_path, "test_datasets", "sample_table_large.csv"))
    with io.open(csv_path, encoding="utf8") as handle:
        header = next(unicode_csv_reader(handle))

    expected_header = [
        "_id",
        "ltable_id",
        "rtable_id",
        "label",
        "ltable_Song_Name",
        "ltable_Artist_Name",
        "ltable_Price",
        "ltable_Released",
        "rtable_Song_Name",
        "rtable_Artist_Name",
        "rtable_Price",
        "rtable_Released",
    ]
    assert header == expected_header

    fields = _make_fields(header, "_id", "label", ["ltable_id", "rtable_id"],
                          True, "nltk", True)
    assert len(fields) == 12

    # Tally how many columns share each field object.
    usage = {}
    for entry in fields:
        usage[entry[1]] = usage.get(entry[1], 0) + 1
    assert sorted(usage.values()) == [1, 1, 2, 8]
def csv_iterator(data_path, ngrams):
    """Yield an n-gram iterator for the text of every CSV row.

    For each row, all columns after the first are joined with spaces,
    tokenized with the "basic_english" tokenizer and wrapped by
    ``ngrams_iterator``.

    Args:
        data_path: Path of the CSV file.
        ngrams: Passed through to ``ngrams_iterator``.
    """
    tokenize = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            text = ' '.join(record[1:])
            yield ngrams_iterator(tokenize(text), ngrams)
def __init__(self, fields, column_naming, path=None, format='csv',
             examples=None, metadata=None, **kwargs):
    """Create a dataset either from a data file or from ready-made examples.

    When ``examples`` is None the file at ``path`` is read: its first record
    is discarded (presumably the header) and each remaining record becomes
    an example, after which the parent constructor is called. Otherwise the
    given ``examples``, ``fields`` and ``metadata`` are stored directly and
    the parent constructor is not called.

    Args:
        fields: Field specification; passed to the example factory and the
            parent constructor, or stored as ``dict(fields)`` when
            ``examples`` is supplied.
        column_naming: Stored on the instance as ``column_naming``.
        path (str): Path to the data file; required when ``examples`` is None.
        format (str): One of "csv", "tsv", "json" or "dict"
            (case-insensitive).
        examples: Pre-built examples, or None to load them from ``path``.
        metadata: Stored on the instance when ``examples`` is supplied.
        kwargs: Passed to the parent class constructor.

    Raises:
        KeyError: If ``format`` is not a supported name.
    """
    if examples is None:
        # Normalise once so that factory lookup and reader selection agree;
        # previously "CSV"/"TSV" picked the CSV factory but the raw-line
        # reader.
        format = format.lower()
        make_example = {
            'json': Example.fromJSON,
            'dict': Example.fromdict,
            'tsv': Example.fromCSV,
            'csv': Example.fromCSV,
        }[format]
        expanded_path = os.path.expanduser(path)

        # First pass: count physical lines for the progress bar total.
        with open(expanded_path, encoding="utf8") as f:
            lines = sum(1 for _ in f)

        with open(expanded_path, encoding="utf8") as f:
            if format == 'csv':
                reader = unicode_csv_reader(f)
            elif format == 'tsv':
                reader = unicode_csv_reader(f, delimiter='\t')
            else:
                reader = f

            next(reader)  # drop the first record
            examples = [
                make_example(line, fields)
                for line in pyprind.prog_bar(
                    reader,
                    iterations=lines,
                    title='\nReading and processing data from "' + path + '"')
            ]

        super(MatchingDataset, self).__init__(examples, fields, **kwargs)
    else:
        self.fields = dict(fields)
        self.examples = examples
        self.metadata = metadata

    self.path = path
    self.column_naming = column_naming
    self._set_attributes()
def __init__(self, path, format, fields, skip_header=False, **kwargs):
    """Create a DuplicateTabularDataset given a path, file format, and field dict.

    A field in the example file can be mapped to multiple field objects,
    breaking the functional constraint.

    Arguments:
        path (str): Path to the data file.
        format (str): The format of the data file. One of "CSV", "TSV", or
            "JSON" (case-insensitive).
        fields dict[str: tuple(tuple(str, Field), ...)]: The keys should be a
            subset of the JSON keys or CSV/TSV columns, and the values should
            be tuples of (name, field). Keys not present in the input
            dictionary are ignored. This allows the user to rename columns
            from their JSON/CSV/TSV key names and also enables selecting a
            subset of columns to load.
        skip_header (bool): Whether to skip the first line of the input
            file. Must be False for CSV/TSV files, whose header is needed to
            locate the columns.

    Raises:
        ValueError: If ``skip_header`` is true for a CSV/TSV file, or if a
            key of ``fields`` does not appear in the header.
        KeyError: If ``format`` is not a supported name.
    """
    # Normalise once so that factory lookup and reader selection agree;
    # previously "CSV"/"TSV" picked the CSV factory but the raw-line reader.
    format = format.lower()
    make_example = {
        'json': Example.fromJSON,
        'dict': Example.fromdict,
        'tsv': Example.fromCSV,
        'csv': Example.fromCSV,
    }[format]

    with io.open(os.path.expanduser(path), encoding="utf8") as f:
        if format == 'csv':
            reader = unicode_csv_reader(f)
        elif format == 'tsv':
            reader = unicode_csv_reader(f, delimiter='\t')
        else:
            reader = f

        if format in ['csv', 'tsv']:
            if skip_header:
                raise ValueError('When using a dict to specify fields with a {} file, '
                                 'skip_header must be False and '
                                 'the file must have a header.'.format(format))
            # The header is consumed here to map column names to positions.
            header = next(reader)
            field_to_index = {column: header.index(column) for column in fields}
            make_example = partial(make_example, field_to_index=field_to_index)

        # Only reachable with skip_header=True for non-CSV/TSV formats.
        if skip_header:
            next(reader)

        examples = [make_example(line, fields) for line in reader]

    super(DuplicateTabularDataset, self).__init__(examples, fields, **kwargs)
def iterator(start, num_lines):
    """Yield ``num_lines`` training samples starting at row ``start``.

    Reads the CSV file at ``data_path`` (taken from the enclosing scope).
    Each sample is ``(label, token_ids)`` where ``label`` is the first
    column minus one and ``token_ids`` is a tensor of ``vocab`` indices for
    the n-grams of the remaining columns. When the end of the file is
    reached, reading wraps around to the first row.

    Args:
        start (int): Zero-based index of the first row to yield. If the file
            has fewer rows, iteration starts from the last row.
        num_lines (int): Number of samples to yield.

    Raises:
        ValueError: If the file has no rows and ``num_lines`` is positive.
    """
    tokenizer = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)

        # Advance to the starting row; `row` keeps the last row seen.
        row = None
        for i, row in enumerate(reader):
            if i == start:
                break
        if row is None:
            # Empty file: previously this surfaced as an UnboundLocalError
            # on `row` below.
            if num_lines > 0:
                raise ValueError('No rows found in "{}"'.format(data_path))
            return

        for _ in range(num_lines):
            tokens = ' '.join(row[1:])
            tokens = ngrams_iterator(tokenizer(tokens), ngrams)
            yield int(row[0]) - 1, torch.tensor(
                [vocab[token] for token in tokens])
            try:
                row = next(reader)
            except StopIteration:
                # Wrap around to the beginning of the file.
                f.seek(0)
                reader = unicode_csv_reader(f)
                row = next(reader)
def count(data_path):
    r"""Return the number of distinct labels and the number of text entries.

    The first column of each CSV row is parsed as an integer label.

    Args:
        data_path: Path of the CSV file to scan.

    Returns:
        tuple: ``(num_labels, num_lines)``.
    """
    with io.open(data_path, encoding="utf8") as source:
        label_ids = [int(columns[0]) for columns in unicode_csv_reader(source)]
    unique_label_total = len(set(label_ids))
    return unique_label_total, len(label_ids)
def _csv_iterator(data_path, tokenizer, ngrams, yield_cls=False):
    """Yield n-gram iterators (optionally with labels) for each CSV row.

    All columns after the first are joined with spaces and tokenized with
    ``tokenizer``.

    Args:
        data_path: Path of the CSV file.
        tokenizer: Callable turning the joined text into tokens.
        ngrams: Passed through to ``ngrams_iterator``.
        yield_cls: If true, yield ``(label, ngram_iterator)`` where ``label``
            is the first column as an int minus one; otherwise yield only
            the n-gram iterator.
    """
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            words = tokenizer(' '.join(record[1:]))
            if not yield_cls:
                yield ngrams_iterator(words, ngrams)
                continue
            label = int(record[0]) - 1
            yield label, ngrams_iterator(words, ngrams)
def process_unlabeled(path, trained_model, ignore_columns=None):
    """Creates a dataset object for an unlabeled dataset.

    Args:
        path (string):
            The full path to the unlabeled data file (not just the directory).
        trained_model (:class:`~deepmatcher.MatchingModel`):
            The trained model. Its ``meta`` attribute supplies the
            configuration of the training data (column naming, lowercasing,
            tokenizer, vocabularies, embeddings), which is reused for the
            unlabeled data.
        ignore_columns (list):
            A list of columns to ignore in the unlabeled CSV file. Defaults
            to the columns ignored during training.

    Returns:
        MatchingDataset: The loaded dataset, with a ``vocabs`` attribute
        holding the vocabulary of every text field.
    """
    with io.open(path, encoding="utf8") as handle:
        header = next(unicode_csv_reader(handle))

    train_info = trained_model.meta
    if ignore_columns is None:
        ignore_columns = train_info.ignore_columns

    # Same naming as training, except that there is no label column.
    column_naming = dict(train_info.column_naming)
    column_naming['label'] = None

    fields = _make_fields(
        header,
        column_naming['id'],
        column_naming['label'],
        ignore_columns,
        train_info.lowercase,
        train_info.tokenize,
        train_info.include_lengths,
    )

    begin = timer()
    dataset = MatchingDataset(path=path, fields=fields, column_naming=column_naming)

    # Make sure we have the same attributes.
    assert set(dataset.all_text_fields) == set(train_info.all_text_fields)

    after_load = timer()
    logger.info('Data load time: {}s'.format(after_load - begin))

    # Field objects may be shared by several columns; visit each one once.
    name_by_field = {pair[1]: pair[0] for pair in fields}
    for field, name in name_by_field.items():
        if field is None or not field.use_vocab:
            continue
        # Copy over vocab from original train data.
        field.vocab = copy.deepcopy(train_info.vocabs[name])
        # Then extend the vocab.
        field.extend_vocab(
            dataset,
            vectors=train_info.embeddings,
            cache=train_info.embeddings_cache)

    dataset.vocabs = {
        name: dataset.fields[name].vocab
        for name in train_info.all_text_fields
    }

    after_vocab = timer()
    logger.info('Vocab update time: {}s'.format(after_vocab - after_load))

    return dataset
def test_check_header_1():
    """The small sample table's header passes ``_check_header`` without error."""
    csv_path = os.path.expanduser(
        os.path.join(test_dir_path, "test_datasets", "sample_table_small.csv"))
    with io.open(csv_path, encoding="utf8") as handle:
        header = next(unicode_csv_reader(handle))

    assert header == ["id", "left_a", "right_a", "label"]

    # Arguments: header, id column, left prefix, right prefix, label column,
    # and an empty list as the last argument.
    _check_header(header, "id", "left", "right", "label", [])
def _csv_iterator(data_path, ngrams, yield_cls=False):
    """Yield n-gram iterators (optionally with a binary label) per CSV row.

    The text is taken from column 1 and tokenized with the "basic_english"
    tokenizer.

    Args:
        data_path: Path of the CSV file.
        ngrams: Passed through to ``ngrams_iterator``.
        yield_cls: If true, yield ``(label, ngram_iterator)`` where ``label``
            is 1 when column 4 parses to the integer 3 and 0 otherwise;
            otherwise yield only the n-gram iterator.
    """
    tokenize = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            words = tokenize(record[1])
            if not yield_cls:
                yield ngrams_iterator(words, ngrams)
                continue
            if int(record[4]) == 3:
                label = 1
            else:
                label = 0
            yield label, ngrams_iterator(words, ngrams)
def _csv_iterator(data_path, ngrams, dataset_name=None, yield_cls=False):
    """Yield n-gram iterators (optionally with labels) per row of a TSV file.

    The text is taken from column 1 and tokenized with the spaCy
    "en_core_web_sm" tokenizer.

    Args:
        data_path: Path of the tab-separated file.
        ngrams: Passed through to ``ngrams_iterator``.
        dataset_name: Key into the module-level ``LABELS`` mapping; only
            used when ``yield_cls`` is true.
        yield_cls: If true, yield ``(label, ngram_iterator)`` where ``label``
            is ``int(LABELS[dataset_name][row[0]]) - 1``; otherwise yield
            only the n-gram iterator.
    """
    tokenize = get_tokenizer("spacy", language="en_core_web_sm")
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle, delimiter='\t'):
            words = tokenize(record[1])
            if not yield_cls:
                yield ngrams_iterator(words, ngrams)
                continue
            label = int(LABELS[dataset_name][record[0]]) - 1
            yield label, ngrams_iterator(words, ngrams)
def test_check_header_5(self):
    """``_check_header`` is called with an empty label attribute name."""
    csv_path = os.path.expanduser(
        os.path.join(test_dir_path, 'test_datasets', 'sample_table_small.csv'))
    with io.open(csv_path, encoding="utf8") as handle:
        header = next(unicode_csv_reader(handle))

    self.assertEqual(header, ['id', 'left_a', 'right_a', 'label'])

    # Arguments: header, id column, left prefix, right prefix, label column
    # (empty here), and an empty list as the last argument.
    _check_header(header, 'id', 'left', 'right', '', [])
def _create_data_with_sp_transform(sp_generator, data_path):
    """Convert a CSV file into ``(label, token-id tensor)`` samples.

    For each row, the columns after the first are joined with spaces and
    passed to ``sp_generator`` as a one-element list; the first item it
    produces is used as the token ids. The label is the first column as an
    int minus one.

    Args:
        sp_generator: Callable mapping a list of strings to an iterable of
            token-id sequences.
        data_path: Path of the CSV file.

    Returns:
        tuple: ``(data, labels)`` -- the list of samples and the set of
        labels seen.
    """
    samples = []
    seen_labels = set()
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            text = ' '.join(record[1:])
            ids = list(sp_generator([text]))[0]
            label = int(record[0]) - 1
            samples.append((label, torch.tensor(ids)))
            seen_labels.add(label)
    return samples, seen_labels
def _csv_iterator(data_path, ngrams, yield_cls=False, label=-1):
    """Yield n-gram iterators (optionally with labels) per row of a TSV file.

    The text is taken from column 5 and tokenized with the "basic_english"
    tokenizer.

    Args:
        data_path: Path of the tab-separated file.
        ngrams: Passed through to ``ngrams_iterator``.
        yield_cls: If true, yield ``(row[7], ngram_iterator)`` -- the label
            is column 7 as read, without conversion; otherwise yield only
            the n-gram iterator.
        label: Not used by this function.
    """
    tokenize = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle, delimiter="\t"):
            words = tokenize(record[5])
            if not yield_cls:
                yield ngrams_iterator(words, ngrams)
                continue
            yield record[7], ngrams_iterator(words, ngrams)
def iterator(start, num_lines, num_classes=20):
    """Yield ``num_lines`` multi-label training samples starting at row ``start``.

    Reads the CSV file at ``data_path`` (taken from the enclosing scope).
    Each sample is ``(label_onehot, token_ids)``: ``label_onehot`` is a list
    of ``num_classes`` floats with 1.0 at every class index listed
    (space-separated) in the first column, and ``token_ids`` is a tensor of
    ``vocab`` indices for the n-grams of the remaining columns. When the end
    of the file is reached, reading wraps around to the first row.

    Args:
        start (int): Zero-based index of the first row to yield. If the file
            has fewer rows, iteration starts from the last row.
        num_lines (int): Number of samples to yield.
        num_classes (int): Length of the label vector. Defaults to 20, the
            previously hard-coded value.

    Raises:
        ValueError: If the file has no rows and ``num_lines`` is positive.
        IndexError: If a class index is outside ``num_classes``.
    """
    tokenizer = get_tokenizer("basic_english")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)

        # Advance to the starting row; `row` keeps the last row seen.
        row = None
        for i, row in enumerate(reader):
            if i == start:
                break
        if row is None:
            # Empty file: previously this surfaced as an UnboundLocalError
            # on `row` below.
            if num_lines > 0:
                raise ValueError('No rows found in "{}"'.format(data_path))
            return

        for _ in range(num_lines):
            tokens = ' '.join(row[1:])
            tokens = ngrams_iterator(tokenizer(tokens), ngrams)

            label_onehot = [0.0] * num_classes
            for class_num in row[0].split(' '):
                label_onehot[int(class_num)] = 1.0

            yield label_onehot, torch.tensor(
                [vocab[token] for token in tokens])
            try:
                row = next(reader)
            except StopIteration:
                # Wrap around to the beginning of the file.
                f.seek(0)
                reader = unicode_csv_reader(f)
                row = next(reader)
def _create_data_with_sp_transform(data_path):
    """Convert a CSV file into ``(label, token-id tensor)`` samples.

    The text pipeline is built by ``sentencepiece_processor`` from the file
    obtained via ``download_from_url`` for the 'text_unigram_15000' entry of
    ``pretrained_sp_model``. For each row, the columns after the first are
    joined with spaces and run through the pipeline; the label is the first
    column as an int minus one.

    Args:
        data_path: Path of the CSV file.

    Returns:
        tuple: ``(data, labels)`` -- the list of samples and the set of
        labels seen.
    """
    model_location = pretrained_sp_model['text_unigram_15000']
    to_ids = sentencepiece_processor(download_from_url(model_location))

    samples = []
    seen_labels = set()
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            ids = to_ids(' '.join(record[1:]))
            label = int(record[0]) - 1
            samples.append((label, torch.tensor(ids)))
            seen_labels.add(label)
    return samples, seen_labels
def csv_iterator(data_path, ngrams, yield_cls=False):
    """Load a CSV text file and generate n-gram token samples from the raw text.

    All columns after the first are joined with spaces and tokenized with
    the module-level ``tokenizer``.

    Args:
        data_path: Path of the CSV file.
        ngrams: Passed through to ``ngrams_iterator``.
        yield_cls: If true, each item is ``(label, ngram_iterator)`` where
            ``label`` is the first column as an int minus one; otherwise
            each item is only the n-gram iterator.

    Returns:
        A generator over the items described above.
    """
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            words = tokenizer(' '.join(record[1:]))
            if not yield_cls:
                yield ngrams_iterator(words, ngrams)
                continue
            label = int(record[0]) - 1
            yield label, ngrams_iterator(words, ngrams)
def test_text_nomalize_function(self):
    """Tokenized AG_NEWS test rows must equal the stored reference tokens."""
    # Test text_nomalize function in torchtext.datasets.text_classification
    tokenizer = data.get_tokenizer("basic_english")

    data_path = 'test/asset/text_normalization_ag_news_test.csv'
    with io.open(data_path, encoding="utf8") as f:
        test_lines = [tokenizer(' , '.join(row))
                      for row in unicode_csv_reader(f)]

    ref_lines = []
    data_path = 'test/asset/text_normalization_ag_news_ref_results.test'
    with io.open(data_path, encoding="utf8") as ref_data:
        for raw_line in ref_data:
            parts = raw_line.split()
            first = parts[0]
            self.assertEqual(first[:9], '__label__')
            parts[0] = first[9:]  # remove '__label__'
            ref_lines.append(parts)

    self.assertEqual(ref_lines, test_lines)
def test_make_fields_1(self):
    """Check the header of the large sample table and the fields built from it."""
    csv_path = os.path.expanduser(
        os.path.join(test_dir_path, 'test_datasets', 'sample_table_large.csv'))
    with io.open(csv_path, encoding="utf8") as handle:
        header = next(unicode_csv_reader(handle))

    expected_header = [
        '_id', 'ltable_id', 'rtable_id', 'label',
        'ltable_Song_Name', 'ltable_Artist_Name', 'ltable_Price', 'ltable_Released',
        'rtable_Song_Name', 'rtable_Artist_Name', 'rtable_Price', 'rtable_Released',
    ]
    self.assertEqual(header, expected_header)

    fields = _make_fields(header, '_id', 'label', ['ltable_id', 'rtable_id'],
                          True, 'nltk', True)
    self.assertEqual(len(fields), 12)

    # Tally how many columns share each field object.
    usage = {}
    for entry in fields:
        usage[entry[1]] = usage.get(entry[1], 0) + 1
    self.assertEqual(sorted(usage.values()), [1, 1, 2, 8])
def setUp(self):
    """Prepare file names, fields and column naming shared by the tests."""
    self.data_dir = os.path.join(test_dir_path, 'test_datasets')
    self.train = 'test_train.csv'
    self.validation = 'test_valid.csv'
    self.test = 'test_test.csv'
    self.cache_name = 'test_cacheddata.pth'

    # Only the header row of the training file is needed to build fields.
    train_path = os.path.expanduser(os.path.join(self.data_dir, self.train))
    with io.open(train_path, encoding="utf8") as handle:
        header = next(unicode_csv_reader(handle))

    id_attr, label_attr = 'id', 'label'
    self.fields = _make_fields(header, id_attr, label_attr,
                               ['left_id', 'right_id'], True, 'nltk', False)
    self.column_naming = dict(
        id=id_attr,
        left='left_',
        right='right_',
        label=label_attr,
    )
def _csv_iterator(data_path, ngrams, skip_header=True, yield_cls=False,
                  label_col=6, token_col=None, label_mapping=None):
    """Yield n-gram iterators (optionally with labels) for each CSV row.

    The text of a row is the space-joined content of the columns listed in
    ``token_col`` (in file column order), tokenized with the spaCy
    "en_core_web_sm" tokenizer.

    Args:
        data_path: Path of the CSV file.
        ngrams: Passed through to ``ngrams_iterator``.
        skip_header (bool): Whether to discard the first row.
        yield_cls (bool): If true, yield ``(label, ngram_iterator)``;
            otherwise yield only the n-gram iterator.
        label_col (int): Index of the column holding the label name.
        token_col: Indices of the columns that make up the text. Defaults to
            columns 1 and 5.
        label_mapping (dict): Maps label names to integer labels. Defaults
            to ``{"simulation": 0, "hardware": 1, "edge_computing": 2}``.

    Raises:
        KeyError: If ``yield_cls`` is true and a row's label is not a key of
            ``label_mapping``.
    """
    # None sentinels replace the former mutable default arguments, which
    # were shared between calls.
    if token_col is None:
        token_col = [1, 5]
    if label_mapping is None:
        label_mapping = {
            "simulation": 0,
            "hardware": 1,
            "edge_computing": 2,
        }

    tokenizer = get_tokenizer("spacy", "en_core_web_sm")
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        if skip_header:
            next(reader, None)  # tolerate an empty file
        for row in reader:
            tokens = ' '.join(
                [cell for index, cell in enumerate(row) if index in token_col])
            tokens = tokenizer(tokens)
            if yield_cls:
                yield label_mapping[row[label_col]], ngrams_iterator(tokens, ngrams)
            else:
                yield ngrams_iterator(tokens, ngrams)
def _create_data_from_csv(data_path):
    """Yield ``(label, text)`` for every row of a CSV file.

    ``label`` is the first column converted to int; ``text`` is the
    remaining columns joined with single spaces.
    """
    with io.open(data_path, encoding="utf8") as handle:
        for record in unicode_csv_reader(handle):
            label = int(record[0])
            text = ' '.join(record[1:])
            yield label, text
def _read_text_iterator(path):
    """Yield each CSV row of ``path`` as one string, columns joined by spaces."""
    with io.open(path, encoding="utf8") as handle:
        for columns in unicode_csv_reader(handle):
            line = " ".join(columns)
            yield line
def process(
    path,
    train=None,
    validation=None,
    test=None,
    cache="cacheddata.pth",
    check_cached_data=True,
    auto_rebuild_cache=True,
    tokenize="nltk",
    lowercase=True,
    embeddings="fasttext.en.bin",
    embeddings_cache_path="~/.vector_cache",
    ignore_columns=(),
    include_lengths=True,
    id_attr="id",
    label_attr="label",
    left_prefix="left_",
    right_prefix="right_",
    use_magellan_convention=False,
    pca=True,
):
    r"""Creates dataset objects for multiple splits of a dataset.

    This involves the following steps (if data cannot be retrieved from the cache):

    #. Read CSV header of a data file and verify header is sane.
    #. Create fields, i.e., column processing specifications (e.g. tokenization, label
       conversion to integers etc.)
    #. Load each data file:

       #. Read each example (tuple pair) in specified CSV file.
       #. Preprocess example. Involves lowercasing and tokenization (unless disabled).
       #. Compute metadata if training data file. \
          See :meth:`MatchingDataset.compute_metadata` for details.

    #. Create vocabulary consisting of all tokens in all attributes in all datasets.
    #. Download word embedding data if necessary.
    #. Create mapping from each word in vocabulary to its word embedding.
    #. Compute metadata
    #. Write to cache

    Arguments:
        path (str): Common prefix of the splits' file paths.
        train (str): Suffix to add to path for the train set.
        validation (str): Suffix to add to path for the validation set, or None
            for no validation set. Default is None.
        test (str): Suffix to add to path for the test set, or None for no test
            set. Default is None.
        cache (str): Suffix to add to path for cache file. If `None` disables caching.
        check_cached_data (bool): Verify that data files haven't changed since the
            cache was constructed and that relevant field options haven't changed.
        auto_rebuild_cache (bool): Automatically rebuild the cache if the data files
            are modified or if the field options change. Defaults to True.
        tokenize (str): Which tokenizer to use
        lowercase (bool): Whether to lowercase all words in all attributes.
        embeddings (str or list): One or more of the following strings:

            * `fasttext.{lang}.bin`:
                This uses sub-word level word embeddings based on binary models from
                "wiki word vectors" released by FastText. {lang} is 'en' or any other
                2 letter ISO 639-1 Language Code, or 3 letter ISO 639-2 Code, if the
                language does not have a 2 letter code. 300d vectors.
                ``fasttext.en.bin`` is the default.
            * `fasttext.wiki.vec`:
                Uses wiki news word vectors released as part of "Advances in
                Pre-Training Distributed Word Representations" by
                Mikolov et al. (2018). 300d vectors.
            * `fasttext.crawl.vec`:
                Uses Common Crawl word vectors released as part of "Advances in
                Pre-Training Distributed Word Representations" by
                Mikolov et al. (2018). 300d vectors.
            * `glove.6B.{dims}`:
                Uses uncased Glove trained on Wiki + Gigaword. {dims} is one of
                (50d, 100d, 200d, or 300d).
            * `glove.42B.300d`:
                Uses uncased Glove trained on Common Crawl. 300d vectors.
            * `glove.840B.300d`:
                Uses cased Glove trained on Common Crawl. 300d vectors.
            * `glove.twitter.27B.{dims}`:
                Uses cased Glove trained on Twitter. {dims} is one of
                (25d, 50d, 100d, or 200d).
        embeddings_cache_path (str): Directory to store downloaded word vector data.
        ignore_columns (list): A list of columns to ignore in the CSV files.
        include_lengths (bool): Whether to provide the model with the lengths of
            each attribute sequence in each batch. If True, length information can be
            used by the neural network, e.g. when picking the last RNN output of each
            attribute sequence.
        id_attr (str): The name of the tuple pair ID column in the CSV file.
        label_attr (str): The name of the tuple pair match label column in the CSV
            file.
        left_prefix (str): The prefix for attribute names belonging to the left table.
        right_prefix (str): The prefix for attribute names belonging to the right
            table.
        use_magellan_convention (bool): Set `id_attr`, `left_prefix`, and
            `right_prefix` according to Magellan (py_entitymatching Python package)
            naming conventions. Specifically, set them to be '_id', 'ltable_', and
            'rtable_' respectively.
        pca (bool): Whether to compute PCA for each attribute (needed for SIF model).
            Defaults to True.

    Returns:
        Tuple[MatchingDataset]: Datasets for (train, validation, and test) splits
            in that order, if provided.

    Raises:
        ValueError: If none of `train`, `validation` or `test` is given.
    """
    if use_magellan_convention:
        id_attr = "_id"
        left_prefix = "ltable_"
        right_prefix = "rtable_"

    # TODO(Sid): check for all datasets to make sure the files exist and have
    # the same schema
    a_dataset = train or validation or test
    if not a_dataset:
        # Without any split there is no file to read the header from;
        # previously this failed inside os.path.join with an obscure error.
        raise ValueError(
            "At least one of `train`, `validation` or `test` must be specified.")

    # The header of the first available split defines the schema.
    with io.open(os.path.expanduser(os.path.join(path, a_dataset)),
                 encoding="utf8") as f:
        header = next(unicode_csv_reader(f))

    _maybe_download_nltk_data()
    _check_header(header, id_attr, left_prefix, right_prefix, label_attr,
                  ignore_columns)
    fields = _make_fields(
        header,
        id_attr,
        label_attr,
        ignore_columns,
        lowercase,
        tokenize,
        include_lengths,
    )

    column_naming = {
        "id": id_attr,
        "left": left_prefix,
        "right": right_prefix,
        "label": label_attr,
    }

    datasets = MatchingDataset.splits(
        path,
        train,
        validation,
        test,
        fields,
        embeddings,
        embeddings_cache_path,
        column_naming,
        cache,
        check_cached_data,
        auto_rebuild_cache,
        train_pca=pca,
    )

    # Save additional information to the first returned dataset (the train
    # dataset when `train` is given).
    datasets[0].ignore_columns = ignore_columns
    datasets[0].tokenize = tokenize
    datasets[0].lowercase = lowercase
    datasets[0].include_lengths = include_lengths

    return datasets
def __init__(self, path, format, fields, skip_header=False,
             csv_reader_params=None, **kwargs):
    """Create a TabularDataset given a path, file format, and field list.

    Args:
        path (str): Path to the data file.
        format (str): The format of the data file. One of "CSV", "TSV", or
            "JSON" (case-insensitive).
        fields ((list(tuple(str, Field)) or dict[str: tuple(str, Field)):
            If using a list, the format must be CSV or TSV, and the values of
            the list should be tuples of (name, field). The fields should be
            in the same order as the columns in the CSV or TSV file, while
            tuples of (name, None) represent columns that will be ignored.

            If using a dict, the keys should be a subset of the JSON keys or
            CSV/TSV columns, and the values should be tuples of
            (name, field). Keys not present in the input dictionary are
            ignored. This allows the user to rename columns from their
            JSON/CSV/TSV key names and also enables selecting a subset of
            columns to load.
        skip_header (bool): Whether to skip the first line of the input file.
        csv_reader_params(dict): Parameters to pass to the csv reader.
            Only relevant when format is csv or tsv.
            See https://docs.python.org/3/library/csv.html#csv.reader
            for more details.
        kwargs (dict): passed to the Dataset parent class.

    Raises:
        ValueError: If ``fields`` is a dict, the format is CSV/TSV and
            ``skip_header`` is true, or if a dict key is not in the header.
        KeyError: If ``format`` is not a supported name.
    """
    if csv_reader_params is None:
        csv_reader_params = {}
    format = format.lower()
    make_example = {
        'json': Example.fromJSON,
        'dict': Example.fromdict,
        'tsv': Example.fromCSV,
        'csv': Example.fromCSV,
    }[format]

    with io.open(os.path.expanduser(path), encoding="utf8") as f:
        if format == 'csv':
            reader = unicode_csv_reader(f, **csv_reader_params)
        elif format == 'tsv':
            reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
        else:
            reader = f

        if format in ['csv', 'tsv'] and isinstance(fields, dict):
            if skip_header:
                # The message fragments previously ran together without
                # spaces ("file,skip_header ... andthe file").
                raise ValueError('When using a dict to specify fields with a {} file, '
                                 'skip_header must be False and '
                                 'the file must have a header.'.format(format))
            # The header is consumed here to map column names to positions.
            header = next(reader)
            field_to_index = {column: header.index(column) for column in fields}
            make_example = partial(make_example, field_to_index=field_to_index)

        if skip_header:
            next(reader)

        examples = [make_example(line, fields) for line in reader]

    # The parent class expects a flat list of fields; dict values may be a
    # single entry or a list of entries.
    if isinstance(fields, dict):
        fields, field_dict = [], fields
        for field in field_dict.values():
            if isinstance(field, list):
                fields.extend(field)
            else:
                fields.append(field)

    super(TabularDataset, self).__init__(examples, fields, **kwargs)