Example #1
    def __init__(self, params, model, num_workers=1, worker_id=0):
        super(WKTDataLayer, self).__init__(params, model, num_workers,
                                           worker_id)

        self._processed_data_folder = self.params.get('processed_data_folder',
                                                      'wkt-processed_data')
        self._data_root = self.params.get('data_root', None)

        self.corp = Corpus(self._data_root, self._processed_data_folder)

        seed_tokens = self.params.get('seed_tokens', 'The').split()

        self.end_token = self.corp.dictionary.word2idx[
            self.corp.dictionary.EOS]
        self.params['seed_tokens'] = [
            self.corp.dictionary.word2idx[seed_token]
            for seed_token in seed_tokens
        ]

        if self.params['mode'] == 'infer':
            self.corp.content = self.params['seed_tokens']

        if self.params['mode'] == 'train':
            self.batch_size = self.params['batch_size']
            self.corp.content = self.corp.train
        elif self.params['mode'] == 'eval':
            self.batch_size = self.params['batch_size']
            self.corp.content = self.corp.valid
        else:
            if len(self.corp.content) < self.params['batch_size']:
                self.batch_size = len(self.corp.content)
            else:
                self.batch_size = self.params['batch_size']

        self.vocab_file = (self._processed_data_folder, 'vocab.txt')
        self.bptt = self.params['bptt']
        self.rand_start = self.params.get('rand_start', False)
        self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
        self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight',
                                                     False)
        self._prefetch_buffer_size = self.params.get('prefetch_buffer_size',
                                                     tf.contrib.data.AUTOTUNE)
        self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
        self._num_workers = num_workers
        self._worker_id = worker_id
        self.delimiter = self.params.get("delimiter", " ")
        self._small = self.params.get("small", False)
        self.start = 0

        # load source and target vocabularies to RAM
        if self._small:
            if self.params['mode'] == 'eval':
                self.corp.content = self.corp.content[:200]
            else:
                self.corp.content = self.corp.content[:9004]

        if self.params.get('pad_vocab_to_eight', False):
            self.corp.content = pad_vocab_to_eight(self.corp.content)

        self.dataset_size = len(self.corp.content)
        self.vocab_size = len(self.corp.dictionary.idx2word)
        self._input_tensors = {}
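
The parameter keys above are read straight from self.params, so a configuration dictionary mirroring them is enough to build the layer. Below is a minimal, hypothetical invocation sketch: the import path, the corpus paths, and passing model=None are assumptions about the surrounding library, not something shown in this snippet.

# Hypothetical usage sketch for Example #1 (import path and paths are assumptions).
from open_seq2seq.data.lm.lmdata import WKTDataLayer

params = {
    'data_root': 'wikitext-raw/',              # placeholder corpus directory
    'processed_data_folder': 'wkt-processed_data',
    'mode': 'train',                           # 'train', 'eval' or 'infer'
    'batch_size': 32,
    'bptt': 96,                                # truncated-BPTT sequence length
    'seed_tokens': 'The',                      # only consulted in 'infer' mode
}

layer = WKTDataLayer(params, model=None, num_workers=1, worker_id=0)
print(layer.vocab_size, layer.dataset_size, layer.end_token)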
Example #2
  def __init__(self, params, model, num_workers=1, worker_id=0):
    super(LMTextDataLayer, self).__init__(params, model,
                                          num_workers, worker_id)

    self._processed_data_folder = self.params.get('processed_data_folder', 'processed_data')
    self._data_root = self.params.get('data_root', None)
    self.corp = Corpus(self._data_root, self._processed_data_folder)
    if self.params['mode'] == 'train':
      self._batch_size = self.params['batch_size']
      self.corp.content = self.corp.train
    elif self.params['mode'] == 'eval':
      self._batch_size = self.params['batch_size']
      self.corp.content = self.corp.valid
    else:
      self._batch_size = 1
      self.corp.content = self.corp.test

    self.vocab_file = (self._processed_data_folder, 'vocab.txt')
    self.bptt = self.params['bptt']
    self.rand_start = self.params.get('rand_start', False)
    self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
    self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
    self._prefetch_buffer_size = self.params.get('prefetch_buffer_size',
                                                 tf.contrib.data.AUTOTUNE)
    self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
    self._num_workers = num_workers
    self._worker_id = worker_id
    self.params["delimiter"] = self.params.get("delimiter", " ")
    self.params["small"] = self.params.get("small", False)
    self.start = 0

    if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
      raise ValueError("If padding to 8 in data layer, then "
                       "max_length should be multiple of 8")

    # load source and target vocabularies to RAM

    seed_tokens = self.params.get('seed_tokens', 'The').split()
    
    self.params['end_token'] = self.corp.dictionary.word2idx[self.corp.dictionary.EOS]
    self.params['seed_tokens'] = [self.corp.dictionary.word2idx[seed_token] for seed_token in seed_tokens]
    if self.params["small"]:
      if self.params['mode'] == 'eval':
        self.corp.content = self.corp.content[:200]
      else:
        self.corp.content = self.corp.content[:9004]


    if self.params.get('pad_vocab_to_eight', False):
      self.corp.content = pad_vocab_to_eight(self.corp.content)

    if self.params['mode'] == 'infer':
      if len(self.corp.content) > self.bptt:
        self.corp.content = self.corp.content[-self.bptt:]

    self.dataset_size = len(self.corp.content)

    self.params['vocab_size'] = len(self.corp.dictionary.idx2word)
    self.PAD_ID = self.params['vocab_size']
    self.PAD = '<pad>'
    self.corp.dictionary.idx2word.append(self.PAD)
    self.corp.dictionary.word2idx[self.PAD] = self.PAD_ID

    self._input_tensors = {}
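
This variant differs from Example #1 in two visible ways: in any mode other than train or eval it reads self.corp.test with a fixed batch size of 1 (and in 'infer' mode additionally trims the content to the last bptt tokens), and it appends a '<pad>' entry to the dictionary after recording vocab_size, so PAD_ID equals the reported vocabulary size. A hedged usage sketch along the same lines as above, with the import path again assumed:

# Hypothetical usage sketch for Example #2 in evaluation mode.
from open_seq2seq.data.lm.lmdata import LMTextDataLayer

params = {
    'data_root': 'wikitext-raw/',        # placeholder corpus directory
    'processed_data_folder': 'processed_data',
    'mode': 'eval',
    'batch_size': 20,
    'bptt': 35,
    'small': True,                       # keep only the first 200 eval tokens
}

layer = LMTextDataLayer(params, model=None)
print(layer.params['vocab_size'], layer.PAD_ID)  # PAD_ID == vocab_size here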
Example #3
  def __init__(self, params, model, num_workers=1, worker_id=0):
    super(ParallelTextDataLayer, self).__init__(params, model,
                                                num_workers, worker_id)
    self._batch_size = self.params['batch_size']
    self.source_file = self.params['source_file']
    self._use_targets = self.params.get('use_targets', True)
    if not self._use_targets:
      self.target_file = self.source_file
      if 'target_file' in self.params:
        print("WARNING: target file was specified but was "
              "ignored by data layer because 'use_targets'=False")
    else:
      self.target_file = self.params['target_file']
    self.src_vocab_file = self.params['src_vocab_file']
    self.tgt_vocab_file = self.params['tgt_vocab_file']
    self.max_len = self.params['max_length']
    self._delimiter = self.params.get('delimiter', ' ')
    self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
    self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
    self._prefetch_buffer_size = self.params.get('prefetch_buffer_size',
                                                 tf.contrib.data.AUTOTUNE)
    self._shuffle_buffer_size = self.params.get('shuffle_buffer_size', -1)
    self._num_workers = num_workers
    self._worker_id = worker_id
    self._use_start_token = self.params.get('use_start_token', True)
    if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
      raise ValueError("If padding to 8 in data layer, then "
                       "max_length should be multiple of 8")

    def file_len(fname):
      # count the number of lines in a text file
      # (assumes the file has at least one line)
      with open(fname, encoding="utf-8") as f:
        for i, l in enumerate(f):
          pass
      return i + 1

    self.dataset_size = file_len(self.source_file)
    special_tokens_already_in_vocab = self.params.get('special_tokens_already_in_vocab', True)

    # load source and target vocabularies to RAM
    self.src_seq2idx = load_pre_existing_vocabulary(
      self.src_vocab_file, min_idx=0 if special_tokens_already_in_vocab
      else SpecialTextTokens.UNK_ID.value + 1)
    self.tgt_seq2idx = load_pre_existing_vocabulary(
      self.tgt_vocab_file, min_idx=0 if special_tokens_already_in_vocab
      else SpecialTextTokens.UNK_ID.value + 1)

    if not special_tokens_already_in_vocab:
      # manually add special tokens
      # unknown symbol
      self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
        SpecialTextTokens.UNK_ID.value
      self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
        SpecialTextTokens.UNK_ID.value
      # sentence start
      self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
        SpecialTextTokens.S_ID.value
      self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
        SpecialTextTokens.S_ID.value
      # sentence end
      self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
        SpecialTextTokens.EOS_ID.value
      self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
        SpecialTextTokens.EOS_ID.value
      # padding
      self.src_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
        SpecialTextTokens.PAD_ID.value
      self.tgt_seq2idx[
        SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
        SpecialTextTokens.PAD_ID.value

    if self.params.get('pad_vocab_to_eight', False):
      self.src_seq2idx = pad_vocab_to_eight(self.src_seq2idx)
      self.tgt_seq2idx = pad_vocab_to_eight(self.tgt_seq2idx)

    self.src_idx2seq = {idx: w for w, idx in self.src_seq2idx.items()}
    self.tgt_idx2seq = {idx: w for w, idx in self.tgt_seq2idx.items()}

    self.params['src_vocab_size'] = len(self.src_seq2idx)
    self.params['tgt_vocab_size'] = len(self.tgt_seq2idx)
    self.params['target_seq2idx'] = self.tgt_seq2idx
    self.params['source_seq2idx'] = self.src_seq2idx
    self.params['target_idx2seq'] = self.tgt_idx2seq
    self.params['source_idx2seq'] = self.src_idx2seq

    self._input_tensors = {}
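
The translation-style constructor above only needs file paths, vocabulary files, and sizing parameters from self.params. A hypothetical configuration sketch follows; the import path, the file names, and the 'mode' key (presumably consumed by the base data layer rather than by this snippet) are assumptions.

# Hypothetical usage sketch for Example #3; paths are placeholders.
from open_seq2seq.data.text2text.text2text import ParallelTextDataLayer

params = {
    'source_file': 'train.clean.en',           # placeholder parallel corpus
    'target_file': 'train.clean.de',
    'src_vocab_file': 'vocab.en',
    'tgt_vocab_file': 'vocab.de',
    'max_length': 56,
    'batch_size': 128,
    'mode': 'train',
    'special_tokens_already_in_vocab': False,  # add UNK/S/EOS/PAD manually
}

layer = ParallelTextDataLayer(params, model=None, num_workers=8, worker_id=0)
print(layer.params['src_vocab_size'], layer.params['tgt_vocab_size'])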
Example #4
  def __init__(self, params, model, num_workers=1, worker_id=0):
    super(ParallelTextDataLayer, self).__init__(params, model,
                                                num_workers, worker_id)
    self._batch_size = self.params['batch_size']
    self.source_file = self.params['source_file']
    self._use_targets = self.params.get('use_targets', True)
    if not self._use_targets:
      self.target_file = self.source_file
      if 'target_file' in self.params:
        print("WARNING: target file was specified but was "
              "ignored by data layer because 'use_targets'=False")
    else:
      self.target_file = self.params['target_file']
    self.src_vocab_file = self.params['src_vocab_file']
    self.tgt_vocab_file = self.params['tgt_vocab_file']
    self.max_len = self.params['max_length']
    self._delimiter = self.params.get('delimiter', ' ')
    self._map_parallel_calls = self.params.get('map_parallel_calls', 8)
    self._pad_lengths_to_eight = self.params.get('pad_lengths_to_eight', False)
    self._prefetch_buffer_size = self.params.get('prefetch_buffer_size', 4)
    self._num_workers = num_workers
    self._worker_id = worker_id
    if self._pad_lengths_to_eight and not (self.params['max_length'] % 8 == 0):
      raise ValueError("If padding to 8 in data layer, then "
                       "max_length should be multiple of 8")

    def file_len(fname):
      # count the number of lines in a text file
      # (assumes the file has at least one line)
      with open(fname) as f:
        for i, l in enumerate(f):
          pass
      return i + 1

    self.dataset_size = file_len(self.source_file)

    # load source and target vocabularies to RAM
    self.src_seq2idx = load_pre_existing_vocabulary(
      self.src_vocab_file,
      min_idx=SpecialTextTokens.UNK_ID.value + 1)
    self.tgt_seq2idx = load_pre_existing_vocabulary(
      self.tgt_vocab_file,
      min_idx=SpecialTextTokens.UNK_ID.value + 1)

    # unknown symbol
    self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
      SpecialTextTokens.UNK_ID.value
    self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.UNK_ID.value)] = \
      SpecialTextTokens.UNK_ID.value

    # sentence start
    self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
      SpecialTextTokens.S_ID.value
    self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.S_ID.value)] = \
      SpecialTextTokens.S_ID.value
    # sentence end
    self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
      SpecialTextTokens.EOS_ID.value
    self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.EOS_ID.value)] = \
      SpecialTextTokens.EOS_ID.value
    # padding
    self.src_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
      SpecialTextTokens.PAD_ID.value
    self.tgt_seq2idx[
      SpecialTextTokens.to_string(SpecialTextTokens.PAD_ID.value)] = \
      SpecialTextTokens.PAD_ID.value

    if self.params.get('pad_vocab_to_eight', False):
      self.src_seq2idx = pad_vocab_to_eight(self.src_seq2idx)
      self.tgt_seq2idx = pad_vocab_to_eight(self.tgt_seq2idx)

    self.src_idx2seq = {idx: w for w, idx in self.src_seq2idx.items()}
    self.tgt_idx2seq = {idx: w for w, idx in self.tgt_seq2idx.items()}

    self.params['src_vocab_size'] = len(self.src_seq2idx)
    self.params['tgt_vocab_size'] = len(self.tgt_seq2idx)
    self.params['target_seq2idx'] = self.tgt_seq2idx
    self.params['source_seq2idx'] = self.src_seq2idx
    self.params['target_idx2seq'] = self.tgt_idx2seq
    self.params['source_idx2seq'] = self.src_idx2seq

    self._input_tensors = {}
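
Example #4 is an earlier revision of the same class: special tokens are always inserted after loading the vocabularies (there is no special_tokens_already_in_vocab switch), and the prefetch buffer defaults to 4 instead of tf.contrib.data.AUTOTUNE. One detail shared by both revisions: the inner file_len helper leaves i unbound when the file is empty, so it raises UnboundLocalError instead of returning 0. A small defensive variant, offered only as a sketch:

def file_len(fname):
    """Count lines in a text file; returns 0 for an empty file."""
    count = 0
    with open(fname, encoding="utf-8") as f:
        for count, _ in enumerate(f, start=1):
            pass
    return count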