def read_sents(self, filename, filter_ids=None):
  if self.vocab is None:
    self.vocab = Vocab()
  vocab_reference = self.vocab if self.include_vocab_reference else None
  return map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in l.strip().split()] + \
                                           [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
             self.iterate_filtered(filename, filter_ids))
def read_sents(self, filename, filter_ids=None):
  if self.vocab is None:
    self.vocab = Vocab()

  def convert(line, segmentation):
    line = line.strip().split()
    ret = AnnotatedSentenceInput(list(map(self.vocab.convert, line)) + [self.vocab.convert(Vocab.ES_STR)])
    ret.annotate("segment", list(map(int, segmentation.strip().split())))
    return ret

  if type(filename) != list:
    try:
      filename = ast.literal_eval(filename)
    except:
      logger.debug("Reading %s with a PlainTextReader instead..." % filename)
      return super(SegmentationTextReader, self).read_sents(filename)

  max_id = None
  if filter_ids is not None:
    max_id = max(filter_ids)
    filter_ids = set(filter_ids)
  data = []
  with open(filename[0], encoding='utf-8') as char_inp, \
       open(filename[1], encoding='utf-8') as seg_inp:
    for sent_count, (char_line, seg_line) in enumerate(zip(char_inp, seg_inp)):
      if filter_ids is None or sent_count in filter_ids:
        data.append(convert(char_line, seg_line))
      if max_id is not None and sent_count > max_id:
        break
  return data
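# A minimal usage sketch, not from the original source: it assumes SegmentationTextReader
# takes the same vocab argument as PlainTextReader (its fallback path above), and the two
# file paths are hypothetical parallel files -- one with space-separated characters per
# line, one with the integer segmentation points for the same line.
reader = SegmentationTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab"))
sents = reader.read_sents(["examples/data/head.ja.chars", "examples/data/head.ja.segs"])
print(len(sents))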
def test_lookup_composer(self):
  enc = self.segmenting_encoder
  word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
  word_vocab.freeze()
  enc.segment_composer = LookupComposer(word_vocab=word_vocab,
                                        src_vocab=self.src_reader.vocab,
                                        hidden_dim=self.layer_dim)
  enc.transduce(self.inp_emb(0))
def test_add_multiple_segment_composer(self):
  enc = self.segmenting_encoder
  word_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
  word_vocab.freeze()
  enc.segment_composer = SumMultipleComposer(composers=[
    LookupComposer(word_vocab=word_vocab, src_vocab=self.src_reader.vocab, hidden_dim=self.layer_dim),
    CharNGramComposer(word_vocab=word_vocab, src_vocab=self.src_reader.vocab, hidden_dim=self.layer_dim)
  ])
  enc.transduce(self.inp_emb(0))
def setUp(self):
  xnmt.events.clear()
  self.hyp = ["the taro met the hanako".split()]
  self.ref = ["taro met hanako".split()]
  vocab = Vocab()
  self.hyp_id = list(map(vocab.convert, self.hyp[0]))
  self.ref_id = list(map(vocab.convert, self.ref[0]))
def read_sents(self, filename, filter_ids=None):
  """
  :param filename: data file
  :param filter_ids: only read sentences with these ids (0-indexed)
  :returns: iterator over sentences from filename
  """
  if self.vocab is None:
    self.vocab = Vocab()
  return self.iterate_filtered(filename, filter_ids)
def read_sents(self, filename: str, filter_ids: Sequence[int] = None) -> Iterator[xnmt.input.Input]:
  """
  Read sentences and return an iterator.

  Args:
    filename: data file
    filter_ids: only read sentences with these ids (0-indexed)

  Returns:
    iterator over sentences from filename
  """
  if self.vocab is None:
    self.vocab = Vocab()
  return self.iterate_filtered(filename, filter_ids)
class PlainTextReader(BaseTextReader, Serializable):
  """
  Handles the typical case of reading plain text files, with one sentence per line.

  Args:
    vocab (Vocab): turns token strings into token IDs
    include_vocab_reference (bool): if True, attach a reference to the vocab to each sentence read
  """
  yaml_tag = '!PlainTextReader'

  def __init__(self, vocab=None, include_vocab_reference=False):
    self.vocab = vocab
    self.include_vocab_reference = include_vocab_reference
    if vocab is not None:
      self.vocab.freeze()
      self.vocab.set_unk(Vocab.UNK_STR)

  def read_sents(self, filename, filter_ids=None):
    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None
    return map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in l.strip().split()] + \
                                             [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
               self.iterate_filtered(filename, filter_ids))

  def count_words(self, trg_words):
    trg_cnt = 0
    for x in trg_words:
      if type(x) == int:
        trg_cnt += 1 if x != Vocab.ES else 0
      else:
        trg_cnt += sum([1 if y != Vocab.ES else 0 for y in x])
    return trg_cnt

  def vocab_size(self):
    return len(self.vocab)
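# A minimal usage sketch, not from the original source: it reads a hypothetical
# one-sentence-per-line file ("examples/data/head.ja" mirrors the vocab path used in
# the tests above). read_sents returns a lazy map, so it is materialized with list().
reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.ja.vocab"))
sents = list(reader.read_sents("examples/data/head.ja"))
print(len(sents), reader.vocab_size())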
seed = 13
random.seed(seed)
np.random.seed(seed)

EXP_DIR = os.path.dirname(__file__)
EXP = "programmatic"

model_file = f"{EXP_DIR}/models/{EXP}.mod"
log_file = f"{EXP_DIR}/logs/{EXP}.log"

xnmt.tee.set_out_file(log_file)

ParamManager.init_param_col()
ParamManager.param_col.model_file = model_file

src_vocab = Vocab(vocab_file="examples/data/head.ja.vocab")
trg_vocab = Vocab(vocab_file="examples/data/head.en.vocab")

batcher = SrcBatcher(batch_size=64)

inference = SimpleInference(batcher=batcher)

layer_dim = 512

model = DefaultTranslator(
  src_reader=PlainTextReader(vocab=src_vocab),
  trg_reader=PlainTextReader(vocab=trg_vocab),
  src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=len(src_vocab)),
  encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim,
def read_sents(self, filename, filter_ids=None):
  def convert_binary_bracketing(parse, lowercase=False):
    transitions = []
    tokens = []
    for word in parse.split(' '):
      if word[0] != "(":
        if word == ")":
          transitions.append(1)
        else:
          # Downcase all words to match GloVe.
          if lowercase:
            tokens.append(word.lower())
          else:
            tokens.append(word)
          transitions.append(0)
    # The transition scheme is hard-coded to right-branching ("rb"); the other
    # branches below are kept for the alternative schemes.
    tree_type = "rb"
    if tree_type == "gt":
      transitions = transitions
    elif tree_type == "lb":
      transitions = lb_build(len(tokens))
    elif tree_type == "rb":
      transitions = rb_build(len(tokens))
    elif tree_type == "bal":
      transitions = balanced_transitions(len(tokens))
    else:
      print("Invalid")
    return tokens, transitions

  def get_tokens(parse):
    return convert_binary_bracketing(parse)[0]

  def get_transitions(parse):
    return convert_binary_bracketing(parse)[1]

  def lb_build(N):
    if N == 2:
      return [0, 0, 1]
    else:
      return lb_build(N - 1) + [0, 1]

  def rb_build(N):
    return [0] * N + [1] * (N - 1)

  def balanced_transitions(N):
    """
    Recursively creates a balanced binary tree with N
    leaves using shift reduce transitions.
    """
    if N == 3:
      return [0, 0, 1, 0, 1]
    elif N == 2:
      return [0, 0, 1]
    elif N == 1:
      return [0]
    else:
      right_N = N // 2
      left_N = N - right_N
      return balanced_transitions(left_N) + balanced_transitions(right_N) + [1]

  if self.vocab is None:
    self.vocab = Vocab()
  vocab_reference = self.vocab if self.include_vocab_reference else None
  r1 = map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in get_tokens(l.strip())] + \
                                         [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
           self.iterate_filtered(filename, filter_ids))
  r2 = map(lambda l: get_transitions(l.strip()), self.iterate_filtered(filename, filter_ids))
  return [r1, r2]
class TreeTextReader(BaseTextReader, Serializable):
  """
  Reads binary-bracketed parse trees, one tree per line, returning both the token
  sequence and the corresponding shift/reduce transition sequence.

  Args:
    vocab (Vocab): turns token strings into token IDs
    include_vocab_reference (bool): if True, attach a reference to the vocab to each sentence read
  """
  yaml_tag = '!TreeTextReader'

  def __init__(self, vocab=None, include_vocab_reference=False):
    self.vocab = vocab
    self.include_vocab_reference = include_vocab_reference
    if vocab is not None:
      self.vocab.freeze()
      self.vocab.set_unk(Vocab.UNK_STR)

  def read_sents(self, filename, filter_ids=None):
    def convert_binary_bracketing(parse, lowercase=False):
      transitions = []
      tokens = []
      for word in parse.split(' '):
        if word[0] != "(":
          if word == ")":
            transitions.append(1)
          else:
            # Downcase all words to match GloVe.
            if lowercase:
              tokens.append(word.lower())
            else:
              tokens.append(word)
            transitions.append(0)
      # The transition scheme is hard-coded to right-branching ("rb"); the other
      # branches below are kept for the alternative schemes.
      tree_type = "rb"
      if tree_type == "gt":
        transitions = transitions
      elif tree_type == "lb":
        transitions = lb_build(len(tokens))
      elif tree_type == "rb":
        transitions = rb_build(len(tokens))
      elif tree_type == "bal":
        transitions = balanced_transitions(len(tokens))
      else:
        print("Invalid")
      return tokens, transitions

    def get_tokens(parse):
      return convert_binary_bracketing(parse)[0]

    def get_transitions(parse):
      return convert_binary_bracketing(parse)[1]

    def lb_build(N):
      if N == 2:
        return [0, 0, 1]
      else:
        return lb_build(N - 1) + [0, 1]

    def rb_build(N):
      return [0] * N + [1] * (N - 1)

    def balanced_transitions(N):
      """
      Recursively creates a balanced binary tree with N
      leaves using shift reduce transitions.
      """
      if N == 3:
        return [0, 0, 1, 0, 1]
      elif N == 2:
        return [0, 0, 1]
      elif N == 1:
        return [0]
      else:
        right_N = N // 2
        left_N = N - right_N
        return balanced_transitions(left_N) + balanced_transitions(right_N) + [1]

    if self.vocab is None:
      self.vocab = Vocab()
    vocab_reference = self.vocab if self.include_vocab_reference else None
    r1 = map(lambda l: SimpleSentenceInput([self.vocab.convert(word) for word in get_tokens(l.strip())] + \
                                           [self.vocab.convert(Vocab.ES_STR)], vocab_reference),
             self.iterate_filtered(filename, filter_ids))
    r2 = map(lambda l: get_transitions(l.strip()), self.iterate_filtered(filename, filter_ids))
    return [r1, r2]

  def count_words(self, trg_words):
    trg_cnt = 0
    for x in trg_words:
      if type(x) == int:
        trg_cnt += 1 if x != Vocab.ES else 0
      else:
        trg_cnt += sum([1 if y != Vocab.ES else 0 for y in x])
    return trg_cnt

  def vocab_size(self):
    return len(self.vocab)
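# A standalone sketch, not part of TreeTextReader, illustrating the shift/reduce
# encoding used above: 0 shifts the next token onto the stack, 1 reduces the top two
# stack elements. These helpers mirror rb_build and balanced_transitions in read_sents.
def rb_build_demo(n):
  # right-branching tree over n tokens: shift everything, then reduce n-1 times
  return [0] * n + [1] * (n - 1)

def balanced_demo(n):
  # balanced binary tree over n tokens
  if n == 1:
    return [0]
  if n == 2:
    return [0, 0, 1]
  if n == 3:
    return [0, 0, 1, 0, 1]
  right = n // 2
  left = n - right
  return balanced_demo(left) + balanced_demo(right) + [1]

print(rb_build_demo(4))   # [0, 0, 0, 0, 1, 1, 1]
print(balanced_demo(4))   # [0, 0, 1, 0, 0, 1, 1]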