Example #1
    def generate(self,
                 src,
                 idx,
                 src_mask=None,
                 forced_trg_ids=None,
                 search_strategy=None):
        self.start_sent(src)
        if not xnmt.batcher.is_batched(src):
            src = xnmt.batcher.mark_as_batch([src])
        else:
            assert src_mask is not None
        outputs = []

        # Seed the target with the start-of-sentence id (0); decode greedily below.
        trg = SimpleSentenceInput([0])

        if not xnmt.batcher.is_batched(trg):
            trg = xnmt.batcher.mark_as_batch([trg])

        output_actions = []
        score = 0.

        # TODO Fix this with output_one_step and use the appropriate search_strategy
        self.max_len = 100  # This is a temporary hack
        for _ in range(self.max_len):
            # Build a fresh DyNet computation graph for each decoding step.
            dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                        check_validity=settings.CHECK_VALIDITY)
            log_prob_tail = self.calc_loss(src,
                                           trg,
                                           loss_cal=None,
                                           infer_prediction=True)
            # Greedily pick the highest-scoring next token.
            ys = np.argmax(log_prob_tail.npvalue(), axis=0).astype('i')
            output_actions.append(ys)
            # Stop once the end-of-sentence token is produced.
            if ys == Vocab.ES:
                break
            # Feed the predictions so far (plus a trailing 0) as the next target prefix.
            trg = SimpleSentenceInput(output_actions + [0])
            if not xnmt.batcher.is_batched(trg):
                trg = xnmt.batcher.mark_as_batch([trg])

        # Write a report for this sentence if a report path is configured.
        sents = src[0]
        if self.report_path is not None:
            src_words = [self.reporting_src_vocab[w] for w in sents]
            trg_words = [self.trg_vocab[w] for w in output_actions]
            self.set_report_input(idx, src_words, trg_words)
            self.set_report_resource("src_words", src_words)
            self.set_report_path('{}.{}'.format(self.report_path, str(idx)))
            self.generate_report(self.report_type)

        # Wrap the hypothesis in a TextOutput when a target vocabulary is available.
        if hasattr(self, "trg_vocab") and self.trg_vocab is not None:
            outputs.append(TextOutput(output_actions, self.trg_vocab))
        else:
            outputs.append((output_actions, score))

        return outputs
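
The loop above implements plain greedy search. A self-contained sketch of the same stopping rule, with made-up log-probabilities and an assumed end-of-sentence id of 1:

    import numpy as np

    ES = 1  # assumed end-of-sentence id
    # Two made-up per-step distributions over a 3-token vocabulary.
    step_log_probs = [np.log([0.1, 0.2, 0.7]),
                      np.log([0.2, 0.7, 0.1])]
    output_actions = []
    for lp in step_log_probs:
        y = int(np.argmax(lp))
        output_actions.append(y)
        if y == ES:  # stop once end-of-sentence is predicted
            break
    print(output_actions)  # [2, 1]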
Example #2
 def read_sent(self, sentence, filter_ids=None):
   chars = []
   segs = []
   offset = 0
   # Record the index of each word's last character as a segment boundary.
   for word in sentence.strip().split():
     offset += len(word)
     segs.append(offset - 1)
     chars.extend(word)
   # Add a final boundary for the appended end-of-sentence token.
   segs.append(len(chars))
   chars.append(Vocab.ES_STR)
   sent_input = SimpleSentenceInput([self.vocab.convert(c) for c in chars])
   sent_input.segment = segs
   return sent_input
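
The boundary arithmetic above is easy to check in isolation. A minimal sketch with no xnmt dependencies (the sentence is made up):

    # Recompute the segment boundaries for a toy sentence.
    sentence = "the cat sat"
    chars, segs, offset = [], [], 0
    for word in sentence.split():
        offset += len(word)
        segs.append(offset - 1)   # index of the word's last character
        chars.extend(word)
    segs.append(len(chars))       # boundary for the end-of-sentence token
    print(segs)  # [2, 5, 8, 9]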
Example #3
 def read_sents(self, filename, filter_ids=None):
   if self.vocab is None:
     self.vocab = Vocab()
   vocab_reference = self.vocab if self.include_vocab_reference else None
   # Convert each filtered line to word ids, appending the end-of-sentence id.
   return map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in l.strip().split()]
                                            + [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
              self.iterate_filtered(filename, filter_ids))
Example #4
 def read_sent(self, sentence):
   # Sample a segmentation during training (subword regularization);
   # otherwise use the deterministic best segmentation.
   if self.sample_train and self.train:
     words = self.subword_model.SampleEncodeAsPieces(sentence.strip(), self.l, self.alpha)
   else:
     words = self.subword_model.EncodeAsPieces(sentence.strip())
   words = [w.decode('utf-8') for w in words]
   return SimpleSentenceInput([self.vocab.convert(word) for word in words]
                              + [self.vocab.convert(Vocab.ES_STR)])
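
The SampleEncodeAsPieces/EncodeAsPieces pair comes from the sentencepiece Python bindings. A standalone sketch of the two paths ("spm.model" is a placeholder path; -1 and 0.1 are illustrative nbest_size and alpha values):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("spm.model")                                    # placeholder model path
    best = sp.EncodeAsPieces("a test sentence")             # deterministic segmentation
    sampled = sp.SampleEncodeAsPieces("a test sentence", -1, 0.1)  # sampled segmentation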
Example #5
 def read_sent(self, line):
   # Convert through the vocabulary if one is given, otherwise parse raw integers.
   if self.vocab:
     self.convert_fct = self.vocab.convert
   else:
     self.convert_fct = int
   # Either return only the sentence length, or the token ids plus end-of-sentence.
   if self.read_sent_len:
     return IntInput(len(line.strip().split()))
   else:
     return SimpleSentenceInput([self.convert_fct(word) for word in line.strip().split()] + [Vocab.ES])
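
When no vocabulary is set, convert_fct falls back to int, so input lines must already contain integer ids. A toy illustration (the line and the end-of-sentence id are made up):

    ES = 1                      # assumed end-of-sentence id
    line = "5 9 2"              # made-up line of integer ids
    sent = [int(w) for w in line.strip().split()] + [ES]
    print(sent)  # [5, 9, 2, 1]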
Example #6
    def generate(self, src, idx, forced_trg_ids=None, search_strategy=None):
        self.start_sent(src)
        if not xnmt.batcher.is_batched(src):
            src = xnmt.batcher.mark_as_batch([src])
        outputs = []

        trg = SimpleSentenceInput([0])

        if not xnmt.batcher.is_batched(trg):
            trg = xnmt.batcher.mark_as_batch([trg])

        output_actions = []
        score = 0.

        # TODO Fix this with generate_one_step and use the appropriate search_strategy
        self.max_len = 100  # This is a temporary hack
        for _ in range(self.max_len):
            dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                        check_validity=settings.CHECK_VALIDITY)
            log_prob_tail = self.calc_loss(src,
                                           trg,
                                           loss_cal=None,
                                           infer_prediction=True)
            # Greedily pick the highest-scoring next token; stop at end-of-sentence.
            ys = np.argmax(log_prob_tail.npvalue(), axis=0).astype('i')
            output_actions.append(ys)
            if ys == Vocab.ES:
                break
            trg = SimpleSentenceInput(output_actions + [0])
            if not xnmt.batcher.is_batched(trg):
                trg = xnmt.batcher.mark_as_batch([trg])

        # Wrap the hypothesis in a TextOutput when a target vocabulary is available.
        if hasattr(self, "trg_vocab") and self.trg_vocab is not None:
            outputs.append(
                TextOutput(actions=output_actions, vocab=self.trg_vocab))
        else:
            outputs.append((output_actions, score))

        return outputs
Example #7
  def read_sents(self, filename, filter_ids=None):
    def convert_binary_bracketing(parse, lowercase=False):
      transitions = []
      tokens = []
      for word in parse.split(' '):
        if word[0] != "(":
          if word == ")":
            transitions.append(1)  # reduce
          else:
            # Downcase all words to match GloVe.
            if lowercase:
              tokens.append(word.lower())
            else:
              tokens.append(word)
            transitions.append(0)  # shift
      # The transition scheme is hardcoded to right-branching ("rb");
      # "gt" would keep the gold transitions from the parse, "lb" and
      # "bal" would build left-branching and balanced trees instead.
      mode = "rb"
      if mode == "gt":
        pass  # keep the transitions read from the parse
      elif mode == "lb":
        transitions = lb_build(len(tokens))
      elif mode == "rb":
        transitions = rb_build(len(tokens))
      elif mode == "bal":
        transitions = balanced_transitions(len(tokens))
      else:
        print("Invalid transition scheme: {}".format(mode))
      return tokens, transitions

    def get_tokens(parse):
      return convert_binary_bracketing(parse)[0]

    def get_transitions(parse):
      return convert_binary_bracketing(parse)[1]

    def lb_build(N):
      # Left-branching tree: shift two tokens, then alternate shift/reduce.
      if N == 2:
        return [0, 0, 1]
      else:
        return lb_build(N - 1) + [0, 1]

    def rb_build(N):
      # Right-branching tree: shift all N tokens, then reduce N-1 times.
      return [0] * N + [1] * (N - 1)

    def balanced_transitions(N):
      """
      Recursively creates a balanced binary tree with N
      leaves using shift reduce transitions.
      """
      if N == 3:
        return [0, 0, 1, 0, 1]
      elif N == 2:
        return [0, 0, 1]
      elif N == 1:
        return [0]
      else:
        right_N = N // 2
        left_N = N - right_N
        return balanced_transitions(left_N) + balanced_transitions(right_N) + [1]

    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None
    # Pair each sentence's token ids with its transition sequence.
    r1 = map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in get_tokens(l.strip())]
                                           + [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
             self.iterate_filtered(filename, filter_ids))
    r2 = map(lambda l: get_transitions(l.strip()),
             self.iterate_filtered(filename, filter_ids))
    return [r1, r2]
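
For intuition, the three builders defined above produce the following shift/reduce sequences (0 = shift, 1 = reduce) for a four-token sentence, assuming the definitions from the snippet are in scope:

    print(lb_build(4))               # [0, 0, 1, 0, 1, 0, 1]  left-branching
    print(rb_build(4))               # [0, 0, 0, 0, 1, 1, 1]  right-branching
    print(balanced_transitions(4))   # [0, 0, 1, 0, 0, 1, 1]  balanced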
Example #8
 def read_sent(self, sentence, filter_ids=None):
   vocab_reference = self.vocab if self.include_vocab_reference else None
   return SimpleSentenceInput([self.vocab.convert(word) for word in sentence.strip().split()]
                              + [self.vocab.convert(Vocab.ES_STR)], vocab_reference)