Exemplo n.º 1
0
 def read_sents(self, filename, filter_ids=None):
   """
   Lazily turn each line of a plain-text file into a SimpleSentenceInput.

   Every line is stripped and split on whitespace. Each token is converted through
   ``self.vocab``, and the id of ``Vocab.ES_STR`` is appended. If no vocab is set yet,
   a fresh ``Vocab()`` is created first.

   :param filename: data file
   :param filter_ids: only read sentences with these ids (0-indexed)
   :returns: a ``map`` iterator; conversion happens only as it is consumed
   """
   if self.vocab is None:
     self.vocab = Vocab()
   vocab_reference = None
   if self.include_vocab_reference:
     vocab_reference = self.vocab

   def line_to_input(line):
     ids = [self.vocab.convert(tok) for tok in line.strip().split()]
     ids.append(self.vocab.convert(Vocab.ES_STR))
     return SimpleSentenceInput(ids, vocab_reference)

   return map(line_to_input, self.iterate_filtered(filename, filter_ids))
Exemplo n.º 2
0
  def read_sents(self, filename, filter_ids=None):
    """
    Read character sentences paired with their segmentation annotations.

    Args:
      filename: either a list/tuple ``[char_file, seg_file]``, or a string holding a Python
        literal of such a pair (e.g. ``"['a.txt', 'a.seg']"``). A string that does not
        evaluate to a list/tuple is read by the parent class's ``read_sents`` instead.
      filter_ids: only read sentences with these ids (0-indexed); None reads everything.
    Returns:
      a list of AnnotatedSentenceInput, one per selected line pair. Each holds the token ids
      plus the ``Vocab.ES_STR`` id and is annotated with "segment" = the ints from the
      matching segmentation line. In the fallback case, whatever the parent reader returns.
    """
    if self.vocab is None:
      self.vocab = Vocab()

    def convert(line, segmentation):
      word_ids = [self.vocab.convert(word) for word in line.strip().split()]
      ret = AnnotatedSentenceInput(word_ids + [self.vocab.convert(Vocab.ES_STR)])
      ret.annotate("segment", [int(s) for s in segmentation.strip().split()])
      return ret

    if not isinstance(filename, (list, tuple)):
      try:
        parsed = ast.literal_eval(filename)
      except (ValueError, SyntaxError, TypeError):
        # Ordinary paths such as "data/train.txt" are not Python literals.
        parsed = None
      if not isinstance(parsed, (list, tuple)):
        logger.debug("Reading %s with a PlainTextReader instead...", filename)
        # Forward filter_ids so the fallback honours the same sentence selection.
        return super(SegmentationTextReader, self).read_sents(filename, filter_ids)
      filename = parsed

    max_id = None
    if filter_ids is not None:
      # Materialize first: max() on a one-shot iterator would exhaust it before set().
      filter_ids = set(filter_ids)
      if not filter_ids:
        return []
      max_id = max(filter_ids)

    data = []
    with open(filename[0], encoding='utf-8') as char_inp, \
         open(filename[1], encoding='utf-8') as seg_inp:
      # zip stops at the shorter of the two files.
      for sent_count, (char_line, seg_line) in enumerate(zip(char_inp, seg_inp)):
        if filter_ids is None or sent_count in filter_ids:
          data.append(convert(char_line, seg_line))
        if max_id is not None and sent_count >= max_id:
          break  # no later line can be selected
    return data
Exemplo n.º 3
0
 def test_lookup_composer(self):
     """Smoke test: run the segmenting encoder with a LookupComposer over the frozen head.ja vocab."""
     encoder = self.segmenting_encoder
     ja_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     ja_vocab.freeze()
     composer = LookupComposer(word_vocab=ja_vocab,
                               src_vocab=self.src_reader.vocab,
                               hidden_dim=self.layer_dim)
     encoder.segment_composer = composer
     encoder.transduce(self.inp_emb(0))
Exemplo n.º 4
0
 def test_add_multiple_segment_composer(self):
     """Smoke test: run the segmenting encoder with a SumMultipleComposer built from a lookup and a char n-gram composer."""
     encoder = self.segmenting_encoder
     ja_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     ja_vocab.freeze()
     lookup = LookupComposer(word_vocab=ja_vocab,
                             src_vocab=self.src_reader.vocab,
                             hidden_dim=self.layer_dim)
     char_ngram = CharNGramComposer(word_vocab=ja_vocab,
                                    src_vocab=self.src_reader.vocab,
                                    hidden_dim=self.layer_dim)
     encoder.segment_composer = SumMultipleComposer(composers=[lookup, char_ngram])
     encoder.transduce(self.inp_emb(0))
Exemplo n.º 5
0
  def setUp(self):
    """Clear xnmt.events, then build token and token-id fixtures for one hypothesis and one reference."""
    xnmt.events.clear()
    hyp_tokens = "the taro met the hanako".split()
    ref_tokens = "taro met hanako".split()
    self.hyp = [hyp_tokens]
    self.ref = [ref_tokens]

    # One shared Vocab, so identical words in hyp and ref get the same id.
    vocab = Vocab()
    self.hyp_id = [vocab.convert(tok) for tok in self.hyp[0]]
    self.ref_id = [vocab.convert(tok) for tok in self.ref[0]]
Exemplo n.º 6
0
 def read_sents(self, filename, filter_ids=None):
     """
     Return the sentences of a file, creating a fresh Vocab() first if none is set.

     :param filename: data file
     :param filter_ids: only read sentences with these ids (0-indexed)
     :returns: iterator over sentences from filename (whatever iterate_filtered yields)
     """
     vocab_missing = self.vocab is None
     if vocab_missing:
         self.vocab = Vocab()
     sentences = self.iterate_filtered(filename, filter_ids)
     return sentences
Exemplo n.º 7
0
  def read_sents(self, filename: str, filter_ids: Sequence[int] = None) -> Iterator[xnmt.input.Input]:
    """
    Return an iterator over the sentences of a file.

    If the reader has no vocab yet, a fresh ``Vocab()`` is created before reading.

    Args:
      filename: data file
      filter_ids: only read sentences with these ids (0-indexed)
    Returns: iterator over sentences from filename
    """
    vocab_missing = self.vocab is None
    if vocab_missing:
      self.vocab = Vocab()
    sentences = self.iterate_filtered(filename, filter_ids)
    return sentences
Exemplo n.º 8
0
class PlainTextReader(BaseTextReader, Serializable):
    """
    Handles the typical case of reading plain text files, with one sentence per line.

    Args:
      vocab (Vocab): turns token strings into token IDs. If given, it is frozen and its
        unknown-word token is set to ``Vocab.UNK_STR``. If omitted, a fresh ``Vocab`` is
        created on the first call to ``read_sents``.
      include_vocab_reference (bool): if True, every SimpleSentenceInput produced by
        ``read_sents`` is given this reader's vocab as its vocab reference; otherwise None.
    """
    yaml_tag = '!PlainTextReader'

    def __init__(self, vocab=None, include_vocab_reference=False):
        self.vocab = vocab
        self.include_vocab_reference = include_vocab_reference
        if vocab is not None:
            self.vocab.freeze()
            self.vocab.set_unk(Vocab.UNK_STR)

    def read_sents(self, filename, filter_ids=None):
        """
        Lazily convert each line of ``filename`` into a SimpleSentenceInput.

        Lines are stripped and split on whitespace. Tokens are converted through the vocab,
        and the id of ``Vocab.ES_STR`` is appended.

        Args:
          filename: data file
          filter_ids: only read sentences with these ids (0-indexed)
        Returns:
          a ``map`` iterator; lines are converted only as it is consumed.
        """
        if self.vocab is None:
            self.vocab = Vocab()
        vocab_reference = self.vocab if self.include_vocab_reference else None

        def to_input(line):
            word_ids = [self.vocab.convert(word) for word in line.strip().split()]
            word_ids.append(self.vocab.convert(Vocab.ES_STR))
            return SimpleSentenceInput(word_ids, vocab_reference)

        return map(to_input, self.iterate_filtered(filename, filter_ids))

    def count_words(self, trg_words):
        """
        Count word ids that are not ``Vocab.ES``.

        Args:
          trg_words: iterable whose elements are either a single int word id or an
            iterable of word ids.
        Returns:
          the number of ids different from ``Vocab.ES``.
        """
        trg_cnt = 0
        for x in trg_words:
            if isinstance(x, int):
                trg_cnt += 1 if x != Vocab.ES else 0
            else:
                trg_cnt += sum(1 for y in x if y != Vocab.ES)
        return trg_cnt

    def vocab_size(self):
        """Return ``len(self.vocab)``."""
        return len(self.vocab)
Exemplo n.º 9
0
seed = 13
random.seed(seed)
np.random.seed(seed)

# Resolve to an absolute directory: os.path.dirname(__file__) is "" when __file__ is a bare
# filename, which would turn the f-strings below into "/models/..." and "/logs/..." at the
# filesystem root.
EXP_DIR = os.path.dirname(os.path.abspath(__file__))
EXP = "programmatic"

model_file = f"{EXP_DIR}/models/{EXP}.mod"
log_file = f"{EXP_DIR}/logs/{EXP}.log"

xnmt.tee.set_out_file(log_file)

ParamManager.init_param_col()
ParamManager.param_col.model_file = model_file

src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")

batcher = SrcBatcher(batch_size=64)

inference = SimpleInference(batcher=batcher)

layer_dim = 512

model = DefaultTranslator(
    src_reader=PlainTextReader(vocab=src_vocab),
    trg_reader=PlainTextReader(vocab=trg_vocab),
    src_embedder=SimpleWordEmbedder(emb_dim=layer_dim,
                                    vocab_size=len(src_vocab)),
    encoder=BiLSTMSeqTransducer(input_dim=layer_dim,
                                hidden_dim=layer_dim,
Exemplo n.º 10
0
  def read_sents(self, filename, filter_ids=None, tree_type="rb"):
    """
    Read one binary-bracketed parse per line, e.g. ``( ( a b ) c )``.

    Each stripped line is split on single spaces. Pieces starting with "(" are skipped,
    ")" is a reduce (1), and anything else is a token and a shift (0).

    Args:
      filename: data file
      filter_ids: only read sentences with these ids (0-indexed)
      tree_type: which transitions to return:
        - "gt": read from the brackets
        - "lb": left-branching
        - "rb": right-branching (the default)
        - "bal": balanced
    Returns:
      ``[sents, transitions]``, two lazy ``map`` iterators, each built from its own
      ``iterate_filtered`` call.
      - ``sents`` yields SimpleSentenceInput objects (token ids plus the ``Vocab.ES_STR`` id).
      - ``transitions`` yields one transition list per line.
    Raises:
      ValueError: if tree_type is not one of "gt", "lb", "rb", "bal".
    """
    def lb_build(n):
      # Left-branching: shift the first leaf, then (shift, reduce) for each remaining leaf.
      return [0] + [0, 1] * (n - 1) if n > 0 else []

    def rb_build(n):
      # Right-branching: shift every leaf, then reduce n - 1 times.
      return [0] * n + [1] * (n - 1)

    def balanced_transitions(n):
      """Recursively create shift/reduce transitions for a balanced binary tree with n leaves."""
      if n <= 0:
        return []
      if n == 1:
        return [0]
      right_n = n // 2
      left_n = n - right_n
      return balanced_transitions(left_n) + balanced_transitions(right_n) + [1]

    builders = {"lb": lb_build, "rb": rb_build, "bal": balanced_transitions}
    if tree_type != "gt" and tree_type not in builders:
      raise ValueError(f"Unknown tree_type {tree_type!r}; expected 'gt', 'lb', 'rb' or 'bal'")

    def convert_binary_bracketing(parse, lowercase=False):
      """Return (tokens, transitions) for one parse string."""
      tokens = []
      transitions = []
      for word in parse.split(' '):
        # Empty pieces come from consecutive spaces; skip them instead of indexing word[0].
        if not word or word[0] == "(":
          continue
        if word == ")":
          transitions.append(1)
        else:
          tokens.append(word.lower() if lowercase else word)
          transitions.append(0)
      if tree_type != "gt":
        transitions = builders[tree_type](len(tokens))
      return tokens, transitions

    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None

    def to_input(line):
      tokens, _ = convert_binary_bracketing(line.strip())
      word_ids = [self.vocab.convert(word) for word in tokens]
      word_ids.append(self.vocab.convert(Vocab.ES_STR))
      return SimpleSentenceInput(word_ids, vocab_reference)

    def to_transitions(line):
      return convert_binary_bracketing(line.strip())[1]

    sents = map(to_input, self.iterate_filtered(filename, filter_ids))
    transitions = map(to_transitions, self.iterate_filtered(filename, filter_ids))
    return [sents, transitions]
Exemplo n.º 11
0
def _lb_build(n):
  """Left-branching shift (0) / reduce (1) transitions for n leaves."""
  return [0] + [0, 1] * (n - 1) if n > 0 else []


def _rb_build(n):
  """Right-branching transitions for n leaves: shift every leaf, then reduce n - 1 times."""
  return [0] * n + [1] * (n - 1)


def _balanced_transitions(n):
  """Recursively create shift/reduce transitions for a balanced binary tree with n leaves."""
  if n <= 0:
    return []
  if n == 1:
    return [0]
  right_n = n // 2
  left_n = n - right_n
  return _balanced_transitions(left_n) + _balanced_transitions(right_n) + [1]


# Synthetic transition builders selectable through TreeTextReader.read_sents(tree_type=...).
_TRANSITION_BUILDERS = {"lb": _lb_build, "rb": _rb_build, "bal": _balanced_transitions}


def _parse_binary_bracketing(parse, tree_type="rb", lowercase=False):
  """
  Split a space-separated binary bracketing such as "( ( a b ) c )" into (tokens, transitions).

  Pieces starting with "(" are skipped, ")" is a reduce (1), and anything else is a token and a
  shift (0). With tree_type "gt" these bracket-derived transitions are returned; otherwise they
  are replaced by the lb/rb/bal transitions for the same number of tokens.
  """
  tokens = []
  transitions = []
  for word in parse.split(' '):
    # Empty pieces come from consecutive spaces; skip them instead of indexing word[0].
    if not word or word[0] == "(":
      continue
    if word == ")":
      transitions.append(1)
    else:
      tokens.append(word.lower() if lowercase else word)
      transitions.append(0)
  if tree_type != "gt":
    transitions = _TRANSITION_BUILDERS[tree_type](len(tokens))
  return tokens, transitions


class TreeTextReader(BaseTextReader, Serializable):
  """
  Reads text files with one binary-bracketed parse per line, e.g. ``( ( a b ) c )``.

  ``read_sents`` returns the sentence tokens and, separately, a shift (0) / reduce (1)
  transition sequence for each line.

  Args:
    vocab (Vocab): turns token strings into token IDs. If given, it is frozen and its
      unknown-word token is set to ``Vocab.UNK_STR``. If omitted, a fresh ``Vocab`` is
      created on the first call to ``read_sents``.
    include_vocab_reference (bool): if True, every SimpleSentenceInput produced by
      ``read_sents`` is given this reader's vocab as its vocab reference; otherwise None.
  """
  yaml_tag = '!TreeTextReader'

  def __init__(self, vocab=None, include_vocab_reference=False):
    self.vocab = vocab
    self.include_vocab_reference = include_vocab_reference
    if vocab is not None:
      self.vocab.freeze()
      self.vocab.set_unk(Vocab.UNK_STR)

  def read_sents(self, filename, filter_ids=None, tree_type="rb"):
    """
    Read the parses in ``filename``.

    Args:
      filename: data file
      filter_ids: only read sentences with these ids (0-indexed)
      tree_type: which transitions to return:
        - "gt": read from the brackets
        - "lb": left-branching
        - "rb": right-branching (the default)
        - "bal": balanced
    Returns:
      ``[sents, transitions]``, two lazy ``map`` iterators, each built from its own
      ``iterate_filtered`` call.
      - ``sents`` yields SimpleSentenceInput objects (token ids plus the ``Vocab.ES_STR`` id).
      - ``transitions`` yields one transition list per line.
    Raises:
      ValueError: if tree_type is not one of "gt", "lb", "rb", "bal".
    """
    if tree_type != "gt" and tree_type not in _TRANSITION_BUILDERS:
      raise ValueError(f"Unknown tree_type {tree_type!r}; expected 'gt', 'lb', 'rb' or 'bal'")
    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None

    def to_input(line):
      tokens, _ = _parse_binary_bracketing(line.strip(), tree_type)
      word_ids = [self.vocab.convert(word) for word in tokens]
      word_ids.append(self.vocab.convert(Vocab.ES_STR))
      return SimpleSentenceInput(word_ids, vocab_reference)

    def to_transitions(line):
      return _parse_binary_bracketing(line.strip(), tree_type)[1]

    sents = map(to_input, self.iterate_filtered(filename, filter_ids))
    transitions = map(to_transitions, self.iterate_filtered(filename, filter_ids))
    return [sents, transitions]

  def count_words(self, trg_words):
    """
    Count word ids that are not ``Vocab.ES``.

    Args:
      trg_words: iterable whose elements are either a single int word id or an
        iterable of word ids.
    Returns:
      the number of ids different from ``Vocab.ES``.
    """
    trg_cnt = 0
    for x in trg_words:
      if isinstance(x, int):
        trg_cnt += 1 if x != Vocab.ES else 0
      else:
        trg_cnt += sum(1 for y in x if y != Vocab.ES)
    return trg_cnt

  def vocab_size(self):
    """Return ``len(self.vocab)``."""
    return len(self.vocab)