def __init__(self, params, model, num_workers=1, worker_id=0):
  super(WKTDataLayer, self).__init__(params, model, num_workers, worker_id)

  self._processed_data_folder = self.params.get(
      'processed_data_folder', 'wkt-processed_data')
  self._data_root = self.params.get('data_root', None)

  self.corp = Corpus(self._data_root, self._processed_data_folder)

  seed_tokens = self.params.get('seed_tokens', 'The').split()

  self.end_token = self.corp.dictionary.word2idx[self.corp.dictionary.EOS]
  self.params['seed_tokens'] = [
      self.corp.dictionary.word2idx[seed_token] for seed_token in seed_tokens
  ]

  if self.params['mode'] == 'infer':
    self.corp.content = self.params['seed_tokens']

  if self.params['mode'] == 'train':
    self.batch_size = self.params['batch_size']
    self.corp.content = self.corp.train
  elif self.params['mode'] == 'eval':
    self.batch_size = self.params['batch_size']
    self.corp.content = self.corp.valid
  else:
    if len(self.corp.content) < self.params['batch_size']:
      self.batch_size = len(self.corp.content)
    else:
      self.batch_size = self.params['batch_size']

  self.vocab_file = (self._processed_data_folder, 'vocab.txt')
  self.bptt = self.params['bptt']
  self.rand_start = self.params.get('rand_start', False)
  self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
  self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
  self._prefetch_buffer_size = self.params.get(
      'prefetch_buffer_size', tf.contrib.data.AUTOTUNE)
  self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
  self._num_workers = num_workers
  self._worker_id = worker_id
  self.delimiter = self.params.get("delimiter", " ")
  self._small = self.params.get("small", False)
  self.start = 0

  # optionally truncate the corpus for quick "small" debugging runs
  if self._small:
    if self.params['mode'] == 'eval':
      self.corp.content = self.corp.content[:200]
    else:
      self.corp.content = self.corp.content[:9004]

  if self.params.get('pad_vocab_to_eight', False):
    self.corp.content = pad_vocab_to_eight(self.corp.content)

  self.dataset_size = len(self.corp.content)
  self.vocab_size = len(self.corp.dictionary.idx2word)
  self._input_tensors = {}
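# Hypothetical usage sketch (added for illustration, not part of the original
# source). It shows the kind of params dict the WKTDataLayer constructor above
# reads; every concrete value (paths, sizes) is an assumption, and the helper
# name below is invented for this sketch.
def _example_wkt_params():
  return {
      'data_root': 'wikitext-2/',               # assumed dataset location
      'processed_data_folder': 'wkt-processed_data',
      'mode': 'train',                          # 'train', 'eval', or 'infer'
      'batch_size': 32,
      'bptt': 96,                               # truncated-BPTT sequence length
      'seed_tokens': 'The',                     # only consumed in 'infer' mode
      'shuffle_buffer_size': 25000,
      'small': False,                           # True -> tiny corpus for debugging
  }
# The dict would then be passed to the constructor, e.g.
# WKTDataLayer(_example_wkt_params(), model).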
def __init__(self, params, model, num_workers=1, worker_id=0):
  super(LMTextDataLayer, self).__init__(params, model, num_workers, worker_id)

  self._processed_data_folder = self.params.get(
      'processed_data_folder', 'processed_data')
  self._data_root = self.params.get('data_root', None)

  self.corp = Corpus(self._data_root, self._processed_data_folder)

  if self.params['mode'] == 'train':
    self._batch_size = self.params['batch_size']
    self.corp.content = self.corp.train
  elif self.params['mode'] == 'eval':
    self._batch_size = self.params['batch_size']
    self.corp.content = self.corp.valid
  else:
    self._batch_size = 1
    self.corp.content = self.corp.test

  self.vocab_file = (self._processed_data_folder, 'vocab.txt')
  self.bptt = self.params['bptt']
  self.rand_start = self.params.get('rand_start', False)
  self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
  self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
  self._prefetch_buffer_size = self.params.get(
      'prefetch_buffer_size', tf.contrib.data.AUTOTUNE)
  self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
  self._num_workers = num_workers
  self._worker_id = worker_id
  self.params["delimiter"] = self.params.get("delimiter", " ")
  self.params["small"] = self.params.get("small", False)
  self.start = 0

  if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
    raise ValueError("If padding to 8 in data layer, then "
                     "max_length should be multiple of 8")

  # map seed tokens and the end-of-sentence marker to vocabulary ids
  seed_tokens = self.params.get('seed_tokens', 'The').split()
  self.params['end_token'] = self.corp.dictionary.word2idx[
      self.corp.dictionary.EOS]
  self.params['seed_tokens'] = [
      self.corp.dictionary.word2idx[seed_token] for seed_token in seed_tokens
  ]

  # optionally truncate the corpus for quick "small" debugging runs
  if self.params["small"]:
    if self.params['mode'] == 'eval':
      self.corp.content = self.corp.content[:200]
    else:
      self.corp.content = self.corp.content[:9004]

  if self.params.get('pad_vocab_to_eight', False):
    self.corp.content = pad_vocab_to_eight(self.corp.content)

  # in inference mode only the last `bptt` tokens are kept as context
  if self.params['mode'] == 'infer':
    if len(self.corp.content) > self.bptt:
      self.corp.content = self.corp.content[-self.bptt:]

  self.dataset_size = len(self.corp.content)
  self.params['vocab_size'] = len(self.corp.dictionary.idx2word)

  # append an explicit <pad> token at the end of the vocabulary
  self.PAD_ID = self.params['vocab_size']
  self.PAD = '<pad>'
  self.corp.dictionary.idx2word.append(self.PAD)
  self.corp.dictionary.word2idx[self.PAD] = self.PAD_ID

  self._input_tensors = {}
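# Hypothetical usage sketch (added for illustration, not part of the original
# source). LMTextDataLayer reads the same style of params dict; in 'infer'
# mode the batch size is forced to 1 and the input is trimmed to the last
# `bptt` tokens. All concrete values below are assumptions, and the helper
# name is invented for this sketch.
def _example_lm_infer_params():
  return {
      'data_root': 'wikitext-103/',     # assumed dataset location
      'processed_data_folder': 'processed_data',
      'mode': 'infer',                  # batch size is forced to 1 in this mode
      'batch_size': 32,                 # ignored for 'infer'
      'bptt': 96,
      'max_length': 96,                 # must be a multiple of 8 if
                                        # 'pad_lengths_to_eight' is True
      'seed_tokens': 'The',
      'pad_lengths_to_eight': False,
  }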
def __init__(self, params, model, num_workers=1, worker_id=0):
  super(ParallelTextDataLayer, self).__init__(
      params, model, num_workers, worker_id)
  self._batch_size = self.params['batch_size']
  self.source_file = self.params['source_file']
  self._use_targets = self.params.get('use_targets', True)
  if not self._use_targets:
    self.target_file = self.source_file
    if 'target_file' in self.params:
      print("WARNING: target file was specified but was "
            "ignored by data layer because 'use_targets'=False")
  else:
    self.target_file = self.params['target_file']
  self.src_vocab_file = self.params['src_vocab_file']
  self.tgt_vocab_file = self.params['tgt_vocab_file']
  self.max_len = self.params['max_length']
  self._delimiter = self.params.get('delimiter', ' ')
  self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
  self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
  self._prefetch_buffer_size = self.params.get(
      'prefetch_buffer_size', tf.contrib.data.AUTOTUNE)
  self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
  self._num_workers = num_workers
  self._worker_id = worker_id
  self._use_start_token = self.params.get('use_start_token', True)

  if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
    raise ValueError("If padding to 8 in data layer, then "
                     "max_length should be multiple of 8")

  def file_len(fname):
    # count lines in a text file (returns 0 for an empty file)
    with open(fname, encoding="utf-8") as f:
      i = -1
      for i, _ in enumerate(f):
        pass
    return i + 1

  self.dataset_size = file_len(self.source_file)

  special_tokens_already_in_vocab = self.params.get(
      'special_tokens_already_in_vocab', True)

  # load source and target vocabularies to RAM
  self.src_seq2idx = load_pre_existing_vocabulary(
      self.src_vocab_file,
      min_idx=0 if special_tokens_already_in_vocab
      else SpecialTextTokens.UNK_ID.value + 1)
  self.tgt_seq2idx = load_pre_existing_vocabulary(
      self.tgt_vocab_file,
      min_idx=0 if special_tokens_already_in_vocab
      else SpecialTextTokens.UNK_ID.value + 1)

  if not special_tokens_already_in_vocab:
    # manually add special tokens
    # unknown symbol
    self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
        SpecialTextTokens.UNK_ID.value
    self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
        SpecialTextTokens.UNK_ID.value
    # sentence start
    self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
        SpecialTextTokens.S_ID.value
    self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
        SpecialTextTokens.S_ID.value
    # sentence end
    self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
        SpecialTextTokens.EOS_ID.value
    self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
        SpecialTextTokens.EOS_ID.value
    # padding
    self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
        SpecialTextTokens.PAD_ID.value
    self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
        SpecialTextTokens.PAD_ID.value

  if self.params.get('pad_vocab_to_eight', False):
    self.src_seq2idx = pad_vocab_to_eight(self.src_seq2idx)
    self.tgt_seq2idx = pad_vocab_to_eight(self.tgt_seq2idx)

  self.src_idx2seq = {idx: w for w, idx in self.src_seq2idx.items()}
  self.tgt_idx2seq = {idx: w for w, idx in self.tgt_seq2idx.items()}

  self.params['src_vocab_size'] = len(self.src_seq2idx)
  self.params['tgt_vocab_size'] = len(self.tgt_seq2idx)
  self.params['target_seq2idx'] = self.tgt_seq2idx
  self.params['source_seq2idx'] = self.src_seq2idx
  self.params['target_idx2seq'] = self.tgt_idx2seq
  self.params['source_idx2seq'] = self.src_idx2seq

  self._input_tensors = {}
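# Hypothetical usage sketch (added for illustration, not part of the original
# source). ParallelTextDataLayer is configured with tokenized source/target
# files plus per-side vocabulary files; file names and sizes below are
# assumptions, and the helper name is invented for this sketch.
def _example_parallel_text_params():
  return {
      'batch_size': 256,
      'source_file': 'train.tok.bpe.en',        # assumed file name
      'target_file': 'train.tok.bpe.de',        # assumed file name
      'src_vocab_file': 'vocab.bpe.en',
      'tgt_vocab_file': 'vocab.bpe.de',
      'max_length': 56,                          # must be a multiple of 8 when
                                                 # 'pad_lengths_to_eight' is True
      'delimiter': ' ',
      'use_targets': True,
      'special_tokens_already_in_vocab': False,  # let the layer insert
                                                 # UNK/S/EOS/PAD itself
  }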
def __init__(self, params, model, num_workers=1, worker_id=0):
  super(ParallelTextDataLayer, self).__init__(
      params, model, num_workers, worker_id)
  self._batch_size = self.params['batch_size']
  self.source_file = self.params['source_file']
  self._use_targets = self.params.get('use_targets', True)
  if not self._use_targets:
    self.target_file = self.source_file
    if 'target_file' in self.params:
      print("WARNING: target file was specified but was "
            "ignored by data layer because 'use_targets'=False")
  else:
    self.target_file = self.params['target_file']
  self.src_vocab_file = self.params['src_vocab_file']
  self.tgt_vocab_file = self.params['tgt_vocab_file']
  self.max_len = self.params['max_length']
  self._delimiter = self.params.get('delimiter', ' ')
  self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
  self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
  self._prefetch_buffer_size = self.params.get('prefetch_buffer_size', 4)
  self._num_workers = num_workers
  self._worker_id = worker_id

  if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
    raise ValueError("If padding to 8 in data layer, then "
                     "max_length should be multiple of 8")

  def file_len(fname):
    with open(fname) as f:
      for i, l in enumerate(f):
        pass
    return i + 1

  self.dataset_size = file_len(self.source_file)

  # load source and target vocabularies to RAM
  self.src_seq2idx = load_pre_existing_vocabulary(
      self.src_vocab_file,
      min_idx=SpecialTextTokens.UNK_ID.value + 1)
  self.tgt_seq2idx = load_pre_existing_vocabulary(
      self.tgt_vocab_file,
      min_idx=SpecialTextTokens.UNK_ID.value + 1)

  # unknown symbol
  self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
      SpecialTextTokens.UNK_ID.value
  self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
      SpecialTextTokens.UNK_ID.value
  # sentence start
  self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
      SpecialTextTokens.S_ID.value
  self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
      SpecialTextTokens.S_ID.value
  # sentence end
  self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
      SpecialTextTokens.EOS_ID.value
  self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
      SpecialTextTokens.EOS_ID.value
  # padding
  self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
      SpecialTextTokens.PAD_ID.value
  self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
      SpecialTextTokens.PAD_ID.value

  if self.params.get('pad_vocab_to_eight', False):
    self.src_seq2idx = pad_vocab_to_eight(self.src_seq2idx)
    self.tgt_seq2idx = pad_vocab_to_eight(self.tgt_seq2idx)

  self.src_idx2seq = {idx: w for w, idx in self.src_seq2idx.items()}
  self.tgt_idx2seq = {idx: w for w, idx in self.tgt_seq2idx.items()}

  self.params['src_vocab_size'] = len(self.src_seq2idx)
  self.params['tgt_vocab_size'] = len(self.tgt_seq2idx)
  self.params['target_seq2idx'] = self.tgt_seq2idx
  self.params['source_seq2idx'] = self.src_seq2idx
  self.params['target_idx2seq'] = self.tgt_idx2seq
  self.params['source_idx2seq'] = self.src_idx2seq

  self._input_tensors = {}