Example #1
File: input.py Project: xxcharles/xnmt
class PlainTextReader(BaseTextReader, Serializable):
    """
    Handles the typical case of reading plain text files,
    with one sentence per line.
    """
    yaml_tag = u'!PlainTextReader'

    def __init__(self, vocab=None, include_vocab_reference=False):
        self.vocab = vocab
        self.include_vocab_reference = include_vocab_reference
        if vocab is not None:
            self.vocab.freeze()
            self.vocab.set_unk(Vocab.UNK_STR)

    def read_sents(self, filename, filter_ids=None):
        if self.vocab is None:
            self.vocab = Vocab()
        vocab_reference = self.vocab if self.include_vocab_reference else None
        return six.moves.map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in l.strip().split()] + \
                                                           [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
                             self.iterate_filtered(filename, filter_ids))

    def freeze(self):
        self.vocab.freeze()
        self.vocab.set_unk(Vocab.UNK_STR)
        self.overwrite_serialize_param("vocab", self.vocab)

    def count_words(self, trg_words):
        trg_cnt = 0
        for x in trg_words:
            if isinstance(x, int):
                trg_cnt += 1 if x != Vocab.ES else 0
            else:
                trg_cnt += sum(1 for y in x if y != Vocab.ES)
        return trg_cnt

    def vocab_size(self):
        return len(self.vocab)
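
A minimal usage sketch, assuming the class is run inside xnmt where Vocab, SimpleSentenceInput, and BaseTextReader are in scope; the file path is illustrative, not part of the example above:

# Sketch only: "train.txt" is a hypothetical path.
reader = PlainTextReader()
train_sents = list(reader.read_sents("train.txt"))  # vocab is built lazily
reader.freeze()  # freeze the vocab; unseen words now map to UNK
print("vocab size:", reader.vocab_size())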
Example #2
class SegmentationTextReader(PlainTextReader):
    """
    Reads a parallel pair of files: filename is a two-element list of paths
    (or its string form), where the first file holds one space-separated
    token sequence per line and the second holds the matching
    space-separated integers used as the "segment" annotation.
    Falls back to PlainTextReader behavior when given a single plain path.
    """
    yaml_tag = '!SegmentationTextReader'

    @serializable_init
    def __init__(self, vocab=None, include_vocab_reference=False):
        super().__init__(vocab=vocab,
                         include_vocab_reference=include_vocab_reference)

    def read_sents(self, filename, filter_ids=None):
        if self.vocab is None:
            self.vocab = Vocab()

        def convert(line, segmentation):
            line = line.strip().split()
            ret = AnnotatedSentenceInput(
                list(map(self.vocab.convert, line)) +
                [self.vocab.convert(Vocab.ES_STR)])
            ret.annotate("segment", list(map(int,
                                             segmentation.strip().split())))
            return ret

        if not isinstance(filename, list):
            try:
                filename = ast.literal_eval(filename)
            except (ValueError, SyntaxError):
                logger.debug("Reading %s with a PlainTextReader instead..." %
                             filename)
                return super().read_sents(filename)

        max_id = None
        if filter_ids is not None:
            max_id = max(filter_ids)
            filter_ids = set(filter_ids)
        data = []
        with open(filename[0], encoding='utf-8') as char_inp,\
             open(filename[1], encoding='utf-8') as seg_inp:
            for sent_count, (char_line,
                             seg_line) in enumerate(zip(char_inp, seg_inp)):
                if filter_ids is None or sent_count in filter_ids:
                    data.append(convert(char_line, seg_line))
                if max_id is not None and sent_count > max_id:
                    break
        return data

    def count_sents(self, filename):
        return super().count_sents(filename[0])
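
The list-valued filename is the notable difference from PlainTextReader; a hedged invocation sketch with illustrative paths:

# Sketch only: paths are hypothetical. filename may be an actual list or
# its string form, which read_sents parses with ast.literal_eval; a plain
# single path falls back to PlainTextReader.read_sents.
reader = SegmentationTextReader()
data = reader.read_sents("['train.chars', 'train.segs']")
plain = reader.read_sents("train.txt")  # falls back to plain-text reading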
Example #3
class TreeTextReader(BaseTextReader, Serializable):
  """
  Handles the typical case of reading plain text files,
  with one sent per line.
  
  Args:
    vocab (Vocab): turns tokens strings into token IDs
    include_vocab_reference (bool): TODO document me
  """
  yaml_tag = '!TreeTextReader'
  def __init__(self, vocab=None, include_vocab_reference=False):
    self.vocab = vocab
    self.include_vocab_reference = include_vocab_reference
    if vocab is not None:
      self.vocab.freeze()
      self.vocab.set_unk(Vocab.UNK_STR)

  def read_sents(self, filename, filter_ids=None):
    def convert_binary_bracketing(parse, lowercase=False):
      # 0 = shift a token, 1 = reduce the top two stack items.
      transitions = []
      tokens = []
      for word in parse.split(' '):
        if word[0] != "(":
          if word == ")":
            transitions.append(1)
          else:
            # Downcase all words to match GloVe.
            if lowercase:
              tokens.append(word.lower())
            else:
              tokens.append(word)
            transitions.append(0)
      # Hardcoded to right-branching ("rb"); "gt" would keep the gold
      # transitions read off the parse itself.
      transition_type = "rb"
      if transition_type == "gt":
        pass  # keep the transitions as parsed
      elif transition_type == "lb":
        transitions = lb_build(len(tokens))
      elif transition_type == "rb":
        transitions = rb_build(len(tokens))
      elif transition_type == "bal":
        transitions = balanced_transitions(len(tokens))
      else:
        raise ValueError("Invalid transition type: %s" % transition_type)
      return tokens, transitions
    def get_tokens(parse):
      return convert_binary_bracketing(parse)[0]

    def get_transitions(parse):
      return convert_binary_bracketing(parse)[1]

    def lb_build(N):
      # Fully left-branching tree: reduce after every shift but the first.
      if N == 2:
        return [0, 0, 1]
      else:
        return lb_build(N - 1) + [0, 1]

    def rb_build(N):
      # Fully right-branching tree: shift all N tokens, then reduce N-1 times.
      return [0] * N + [1] * (N - 1)
    def balanced_transitions(N):
      """
      Recursively creates a balanced binary tree with N
      leaves using shift reduce transitions.
      """
      if N == 3:
        return [0, 0, 1, 0, 1]
      elif N == 2:
        return [0, 0, 1]
      elif N == 1:
        return [0]
      else:
        right_N = N // 2
        left_N = N - right_N
        return balanced_transitions(left_N) + balanced_transitions(right_N) + [1]

    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None
    # Two passes over the file: token IDs first, then transition sequences.
    token_seqs = map(lambda l: SimpleSentenceInput(
                       [self.vocab.convert(word) for word in get_tokens(l.strip())] +
                       [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
                     self.iterate_filtered(filename, filter_ids))
    transition_seqs = map(lambda l: get_transitions(l.strip()),
                          self.iterate_filtered(filename, filter_ids))
    return [token_seqs, transition_seqs]

  def count_words(self, trg_words):
    trg_cnt = 0
    for x in trg_words:
      if isinstance(x, int):
        trg_cnt += 1 if x != Vocab.ES else 0
      else:
        trg_cnt += sum(1 for y in x if y != Vocab.ES)
    return trg_cnt

  def vocab_size(self):
    return len(self.vocab)
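
To make the transition schemes in read_sents concrete, here are the outputs for a four-token sentence, derived by hand from the helper definitions above (0 = shift, 1 = reduce):

# Worked example for N = 4 tokens (not part of the original file):
#   rb_build(4)             -> [0, 0, 0, 0, 1, 1, 1]   right-branching
#   lb_build(4)             -> [0, 0, 1, 0, 1, 0, 1]   left-branching
#   balanced_transitions(4) -> [0, 0, 1, 0, 0, 1, 1]   balanced tree
# Each scheme shifts N times and reduces N - 1 times.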