예제 #1
0
 def test_lookup_composer(self):
     """Smoke test: the segmenting encoder can transduce input when its
     segment composer is a LookupComposer over a frozen word vocabulary.

     Passes as long as no exception is raised; the transduce result is not
     inspected.
     """
     encoder = self.segmenting_encoder
     frozen_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     frozen_vocab.freeze()
     composer = LookupComposer(word_vocab=frozen_vocab,
                               src_vocab=self.src_reader.vocab,
                               hidden_dim=self.layer_dim)
     encoder.segment_composer = composer
     embedded = self.inp_emb(0)
     encoder.transduce(embedded)
예제 #2
0
 def test_add_multiple_segment_composer(self):
     """Smoke test: the segmenting encoder can transduce input when its
     segment composer is a SumMultipleComposer combining a LookupComposer
     and a CharNGramComposer that share one frozen word vocabulary.

     Passes as long as no exception is raised; the transduce result is not
     inspected.
     """
     encoder = self.segmenting_encoder
     frozen_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
     frozen_vocab.freeze()

     def build(composer_cls):
         # Both sub-composers are configured with identical vocab/dim settings.
         return composer_cls(word_vocab=frozen_vocab,
                             src_vocab=self.src_reader.vocab,
                             hidden_dim=self.layer_dim)

     encoder.segment_composer = SumMultipleComposer(
         composers=[build(LookupComposer), build(CharNGramComposer)])
     encoder.transduce(self.inp_emb(0))
예제 #3
0
파일: input.py 프로젝트: xxcharles/xnmt
class PlainTextReader(BaseTextReader, Serializable):
    """
    Handles the typical case of reading plain text files,
    with one sentence per line.

    Each line is split on whitespace, every token is converted to an id with
    ``self.vocab``, and the id of ``Vocab.ES_STR`` is appended as the final
    element of the sentence.

    Args:
      vocab: vocabulary used to convert tokens to ids. If given, it is frozen
        and its unknown-word token is set to ``Vocab.UNK_STR`` at construction
        time. If ``None``, a fresh ``Vocab`` is created on the first
        ``read_sents`` call.
      include_vocab_reference: if True, every produced sentence is given a
        reference to ``self.vocab``; otherwise ``None`` is passed instead.
    """
    yaml_tag = u'!PlainTextReader'

    def __init__(self, vocab=None, include_vocab_reference=False):
        self.vocab = vocab
        self.include_vocab_reference = include_vocab_reference
        if vocab is not None:
            # A vocabulary supplied up front is treated as final: freeze it and
            # route out-of-vocabulary tokens to the UNK entry.
            self.vocab.freeze()
            self.vocab.set_unk(Vocab.UNK_STR)

    def read_sents(self, filename, filter_ids=None):
        """Read sentences from ``filename``, one per line.

        Args:
          filename: path of the plain-text file to read.
          filter_ids: forwarded unchanged to ``iterate_filtered`` (presumably
            selects which lines are returned — TODO confirm).

        Returns:
          A lazy iterator (``six.moves.map``) of ``SimpleSentenceInput``
          objects. Because it is lazy, tokens are converted (and, if the vocab
          is not frozen, added to it) only as the iterator is consumed.
        """
        if self.vocab is None:
            self.vocab = Vocab()
        vocab_reference = self.vocab if self.include_vocab_reference else None

        def _line_to_sent(line):
            # Convert the line's tokens first, then append the terminator id,
            # preserving the original conversion order.
            word_ids = [self.vocab.convert(word) for word in line.strip().split()]
            word_ids.append(self.vocab.convert(Vocab.ES_STR))
            return SimpleSentenceInput(word_ids, vocab_reference)

        return six.moves.map(_line_to_sent,
                             self.iterate_filtered(filename, filter_ids))

    def freeze(self):
        """Freeze the vocabulary and set its unknown-word token.

        Also passes the frozen vocab to ``overwrite_serialize_param`` under the
        key ``"vocab"`` (presumably so it is saved when this reader is
        serialized — confirm against ``Serializable``). Requires ``self.vocab``
        to be set, either at construction or by a prior ``read_sents`` call.
        """
        self.vocab.freeze()
        self.vocab.set_unk(Vocab.UNK_STR)
        self.overwrite_serialize_param("vocab", self.vocab)

    def count_words(self, trg_words):
        """Count the ids in ``trg_words`` that are not ``Vocab.ES``.

        Args:
          trg_words: an iterable whose elements are either single integer ids
            or iterables of ids (both kinds may be mixed).

        Returns:
          The number of ids not equal to ``Vocab.ES``.
        """
        import numbers  # local import: accept any integral id type

        trg_cnt = 0
        for x in trg_words:
            # numbers.Integral also matches numpy integer scalars and Python 2
            # ``long``; the old ``type(x) == int`` check sent those to the
            # iterable branch, where ``for y in x`` raised TypeError.
            if isinstance(x, numbers.Integral):
                trg_cnt += 1 if x != Vocab.ES else 0
            else:
                trg_cnt += sum(1 for y in x if y != Vocab.ES)
        return trg_cnt

    def vocab_size(self):
        """Return the number of entries in ``self.vocab``."""
        return len(self.vocab)