def test_lookup_composer(self):
    """Smoke test: the segmenting encoder transduces with a LookupComposer attached."""
    vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
    vocab.freeze()
    encoder = self.segmenting_encoder
    encoder.segment_composer = LookupComposer(word_vocab=vocab,
                                              src_vocab=self.src_reader.vocab,
                                              hidden_dim=self.layer_dim)
    encoder.transduce(self.inp_emb(0))
def test_add_multiple_segment_composer(self):
    """Smoke test: the segmenting encoder transduces with a SumMultipleComposer
    that combines a LookupComposer and a CharNGramComposer."""
    vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
    vocab.freeze()
    sub_composers = [
        LookupComposer(word_vocab=vocab,
                       src_vocab=self.src_reader.vocab,
                       hidden_dim=self.layer_dim),
        CharNGramComposer(word_vocab=vocab,
                          src_vocab=self.src_reader.vocab,
                          hidden_dim=self.layer_dim),
    ]
    encoder = self.segmenting_encoder
    encoder.segment_composer = SumMultipleComposer(composers=sub_composers)
    encoder.transduce(self.inp_emb(0))
class PlainTextReader(BaseTextReader, Serializable):
  """
  Handles the typical case of reading plain text files, with one sent per line.
  """
  yaml_tag = u'!PlainTextReader'

  def __init__(self, vocab=None, include_vocab_reference=False):
    """
    Args:
      vocab: optional pre-built vocabulary; if given, it is frozen immediately
             and unknown words map to the UNK token.
      include_vocab_reference: if True, each produced sentence carries a
             reference to the vocab object.
    """
    self.vocab = vocab
    self.include_vocab_reference = include_vocab_reference
    if vocab is not None:
      self.vocab.freeze()
      self.vocab.set_unk(Vocab.UNK_STR)

  def read_sents(self, filename, filter_ids=None):
    """Lazily yield SimpleSentenceInput objects, one per line of the file.

    Each sentence is the whitespace-split tokens converted to ids, with the
    end-of-sentence id appended. Creates a fresh (unfrozen) Vocab on first use
    if none was supplied.

    Args:
      filename: path of the plain-text file to read.
      filter_ids: optional collection of line indices to keep
                  (passed through to iterate_filtered).
    Returns:
      a lazy iterator (six.moves.map) over SimpleSentenceInput.
    """
    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None
    return six.moves.map(
        lambda l: SimpleSentenceInput(
            [self.vocab.convert(word) for word in l.strip().split()]
            + [self.vocab.convert(Vocab.ES_STR)],
            vocab_reference),
        self.iterate_filtered(filename, filter_ids))

  def freeze(self):
    """Freeze the vocab, route unseen words to UNK, and record the final vocab
    as a serialization parameter."""
    self.vocab.freeze()
    self.vocab.set_unk(Vocab.UNK_STR)
    self.overwrite_serialize_param("vocab", self.vocab)

  def count_words(self, trg_words):
    """Count tokens, excluding end-of-sentence markers.

    Args:
      trg_words: iterable whose items are either single word ids (ints) or
                 iterables of word ids (one per sentence).
    Returns:
      total number of non-ES word ids.
    """
    trg_cnt = 0
    for x in trg_words:
      # isinstance instead of `type(x) == int`: also accepts int subclasses.
      if isinstance(x, int):
        trg_cnt += 1 if x != Vocab.ES else 0
      else:
        # idiomatic generator count instead of summing a 0/1 list
        trg_cnt += sum(1 for y in x if y != Vocab.ES)
    return trg_cnt

  def vocab_size(self):
    """Return the number of entries in the vocabulary."""
    return len(self.vocab)