def test_tf_decoder(self):
    texts = [
        "i went to chocolate @END@", "awesome is @END@",
        "the meaning of life @END@"
    ]
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            outprobs = torch.rand(len(x), x.query_encoder.vocab.number_of_ids())
            return outprobs, x

    dec = SeqDecoder(TFTransition(Model()))
    y = dec(x)
    print(y[1].followed_actions)
    outactions = y[1].followed_actions.detach().cpu().numpy()
    print(outactions[0])
    print(se.vocab.print(outactions[0]))
    print(se.vocab.print(outactions[1]))
    print(se.vocab.print(outactions[2]))
    self.assertTrue(se.vocab.print(outactions[0]) == texts[0])
    self.assertTrue(se.vocab.print(outactions[1]) == texts[1])
    self.assertTrue(se.vocab.print(outactions[2]) == texts[2])
def test_beam_search(self):
    texts = [
        "i went to chocolate @END@", "awesome is @END@",
        "the meaning of life @END@"
    ]
    from parseq.vocab import SequenceEncoder
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)
    x.start_decoding()

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            outprobs = torch.randn(len(x), x.query_encoder.vocab.number_of_ids())
            outprobs = torch.nn.functional.log_softmax(outprobs, -1)
            return outprobs, x

    model = Model()
    beamsize = 50
    maxtime = 10
    bs = BeamDecoder(model,
                     eval=[CELoss(ignore_index=0), SeqAccuracies()],
                     eval_beam=[BeamSeqAccuracies()],
                     beamsize=beamsize,
                     maxtime=maxtime)
    y = bs(x)
    print(y)
def test_free_decoder(self):
    texts = [
        "i went to chocolate a b c d e f g h i j k l m n o p q r @END@",
        "awesome is @END@", "the meaning of life @END@"
    ]
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    texts = ["@END@"] * 100
    x = BasicDecoderState(texts, texts, se, se)

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            outprobs = torch.rand(len(x), x.query_encoder.vocab.number_of_ids())
            return outprobs, x

    MAXTIME = 10
    dec = SeqDecoder(FreerunningTransition(Model(), maxtime=MAXTIME))
    y = dec(x)
    print(y[1].followed_actions)
    print(max([len(y[1].followed_actions[i]) for i in range(len(y[1]))]))
    print(min([len(y[1].followed_actions[i]) for i in range(len(y[1]))]))
    self.assertTrue(
        max([len(y[1].followed_actions[i]) for i in range(len(y[1]))]) <= MAXTIME + 1)
def load_ds(domain="restaurants", min_freq=0, top_k=np.infty,
            nl_mode="bart-large", trainonvalid=False):
    ds = OvernightDatasetLoader(simplify_mode="light").load(
        domain=domain, trainonvalid=trainonvalid)

    seqenc_vocab = Vocab(padid=1, startid=0, endid=2, unkid=UNKID)
    seqenc = SequenceEncoder(vocab=seqenc_vocab,
                             tokenizer=tree_to_lisp_tokens,
                             add_start_token=True,
                             add_end_token=True)
    for example in ds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] == "train")
    seqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)

    def tokenize(x):
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               seqenc.convert(x[1], return_what="tensor"),
               x[2], x[0], x[1])
        return ret

    tds, vds, xds = ds[(None, None, "train")].map(tokenize), \
                    ds[(None, None, "valid")].map(tokenize), \
                    ds[(None, None, "test")].map(tokenize)
    return tds, vds, xds, nl_tokenizer, seqenc
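# Illustrative usage sketch (not part of the original file): one way the splits
# returned by load_ds above could be inspected. The domain value and the exact
# behavior of the mapped Dataset objects are assumptions for demonstration only.
def _demo_load_ds():
    tds, vds, xds, nl_tokenizer, seqenc = load_ds(domain="restaurants")
    print(f"train={len(tds)}, valid={len(vds)}, test={len(xds)}")
    print(f"FL vocab size: {seqenc.vocab.number_of_ids()}")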
def test_beam_transition(self):
    texts = [
        "i went to chocolate @END@", "awesome is @END@",
        "the meaning of life @END@"
    ]
    from parseq.vocab import SequenceEncoder
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)
    x.start_decoding()

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            outprobs = torch.randn(len(x), x.query_encoder.vocab.number_of_ids())
            outprobs = torch.nn.functional.log_softmax(outprobs, -1)
            return outprobs, x

    model = Model()
    beamsize = 50
    maxtime = 10
    beam_xs = [x.make_copy(detach=False, deep=True) for _ in range(beamsize)]
    beam_states = BeamState(beam_xs)
    print(len(beam_xs))
    print(len(beam_states))

    bt = BeamTransition(model, beamsize, maxtime=maxtime)
    i = 0
    _, _, y, _ = bt(x, i)
    i += 1
    _, _, y, _ = bt(y, i)
    all_terminated = y.all_terminated()
    while not all_terminated:
        _, predactions, y, all_terminated = bt(y, i)
        i += 1

    print("timesteps done:")
    print(i)
    print(y)
    print(predactions[0])
    for i in range(beamsize):
        print("-")
        # print(y.bstates[0].get(i).followed_actions)
        # print(predactions[0, i, :])
        pa = predactions[0, i, :]
        # print((pa == se.vocab[se.vocab.endtoken]).cumsum(0))
        pa = ((pa == se.vocab[se.vocab.endtoken]).long().cumsum(0) < 1).long() * pa
        yb = y.bstates[0].get(i).followed_actions[0, :]
        yb = yb * (yb != se.vocab[se.vocab.endtoken]).long()
        print(pa)
        print(yb)
        self.assertTrue(torch.allclose(pa, yb))
def test_beam_search_vs_greedy(self):
    with torch.no_grad():
        texts = ["a b"] * 10
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            transition_tensor = torch.tensor(
                [[0, 0, 0, 0, .51, .49],
                 [0, 0, 0, 0, .51, .49],
                 [0, 0, 0, 0, .51, .49],
                 [0, 0, 0, 0, .51, .49],
                 [0, 0, 0, 0, .51, .49],
                 [0, 0, 0, 0, .01, .99]])

            def forward(self, x: BasicDecoderState):
                prev = x.prev_actions
                outprobs = self.transition_tensor[prev]
                outprobs = torch.log(outprobs)
                return outprobs, x

        model = Model()
        beamsize = 50
        maxtime = 10
        beam_xs = [x.make_copy(detach=False, deep=True) for _ in range(beamsize)]
        beam_states = BeamState(beam_xs)
        print(len(beam_xs))
        print(len(beam_states))

        bt = BeamTransition(model, beamsize, maxtime=maxtime)
        i = 0
        _, _, y, _ = bt(x, i)
        i += 1
        _, _, y, _ = bt(y, i)
        all_terminated = y.all_terminated()
        while not all_terminated:
            start_time = time.time()
            _, _, y, all_terminated = bt(y, i)
            i += 1
            # print(i)
            end_time = time.time()
            print(f"{i}: {end_time - start_time}")

        print(y)
        print(y.bstates.get(0).followed_actions)
class GeoQueryDatasetSub(GeoQueryDatasetFunQL):
    def __init__(self, p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None, min_freq: int = 2, **kw):
        super(GeoQueryDatasetSub, self).__init__(p, sentence_encoder, min_freq, **kw)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]
        splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs
        queries = self.lisp2prolog(queries)

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def lisp2prolog(self, data: List[str]):
        ret = []
        for x in data:
            pas = lisp_to_pas(x)
            prolog = pas_to_prolog(pas)
            ret.append(prolog)
        return ret
class ConditionalRecallDataset(object):
    def __init__(self, maxlen=10, NperY=10, **kw):
        super(ConditionalRecallDataset, self).__init__(**kw)
        self.data = {}
        self.NperY, self.maxlen = NperY, maxlen
        self._seqs, self._ys = gen_data(self.maxlen, self.NperY)
        self.encoder = SequenceEncoder(tokenizer=lambda x: list(x))

        for seq, y in zip(self._seqs, self._ys):
            self.encoder.inc_build_vocab(seq)
            self.encoder.inc_build_vocab(y)

        self.N = len(self._seqs)
        N = self.N
        splits = ["train"] * int(N * 0.8) + ["valid"] * int(N * 0.1) + ["test"] * int(N * 0.1)
        random.shuffle(splits)
        self.encoder.finalize_vocab()
        self.build_data(self._seqs, self._ys, splits)

    def build_data(self, seqs, ys, splits):
        for seq, y, split in zip(seqs, ys, splits):
            seq_tensor = self.encoder.convert(seq, return_what="tensor")
            y_tensor = self.encoder.convert(y, return_what="tensor")
            if split not in self.data:
                self.data[split] = []
            self.data[split].append((seq_tensor[0], y_tensor[0][0]))

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            assert (split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle)
            return dl
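# Illustrative usage sketch (not part of the original file): building the toy
# conditional-recall dataset and asking for one DataLoader per split. How batches
# of variable-length sequences collate is left to the default collate function
# here, so this only shows the intended calling pattern.
def _demo_conditional_recall():
    ds = ConditionalRecallDataset(maxlen=10, NperY=10)
    dls = ds.dataloader(batsize=5)   # dict: split name -> DataLoader
    print(list(dls.keys()))
    print(f"vocab size: {ds.encoder.vocab.number_of_ids()}")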
def test_decoder_API(self):
    texts = ["i went to chocolate", "awesome is", "the meaning of life"]
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)
    print(x.inp_tensor)

    print("terminated")
    print(x.is_terminated())
    print(x.all_terminated())

    print("prev_actions")
    x.start_decoding()
    print(x.prev_actions)

    print("step")
    x.step(["i", torch.tensor([7]), "the"])
    print(x.prev_actions)
    print(x.followed_actions)
def try_tokenizer_dataset():
    from transformers import BartTokenizer

    ovd = OvernightDatasetLoader().load()
    seqenc = SequenceEncoder(tokenizer=tree_to_lisp_tokens)
    for example in ovd.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] == "train")
    seqenc.finalize_vocab()
    nl_tokenizer = BartTokenizer.from_pretrained("bart-large")

    def tokenize(x):
        ret = [xe for xe in x]
        ret.append(nl_tokenizer.tokenize(ret[0]))
        ret.append(nl_tokenizer.encode(ret[0], return_tensors="pt"))
        ret.append(seqenc.convert(ret[1], return_what="tensor")[0][None])
        return ret

    ovd = ovd.map(tokenize)
    print(ovd[0])
def test_tf_decoder_with_losses_with_gold(self):
    texts = [
        "i went to chocolate @END@", "awesome is @END@",
        "the meaning of life @END@"
    ]
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            outprobs = torch.zeros(len(x), x.query_encoder.vocab.number_of_ids())
            golds = x.get_gold().gather(1, torch.tensor(x._timesteps).to(torch.long)[:, None])
            outprobs.scatter_(1, golds, 1)
            return outprobs, x

    celoss = CELoss(ignore_index=0)
    accs = SeqAccuracies()
    dec = SeqDecoder(TFTransition(Model()), eval=[celoss, accs])
    y = dec(x)
    print(y[0])
    print(y[1].followed_actions)
    print(y[1].get_gold())
    self.assertEqual(y[0]["seq_acc"], 1)
    self.assertEqual(y[0]["elem_acc"], 1)

    # print(y[1].followed_actions)
    outactions = y[1].followed_actions.detach().cpu().numpy()
    # print(outactions[0])
    # print(se.vocab.print(outactions[0]))
    # print(se.vocab.print(outactions[1]))
    # print(se.vocab.print(outactions[2]))
    self.assertTrue(se.vocab.print(outactions[0]) == texts[0])
    self.assertTrue(se.vocab.print(outactions[1]) == texts[1])
    self.assertTrue(se.vocab.print(outactions[2]) == texts[2])
def test_create(self):
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    texts = [
        "i went to chocolate", "awesome is @PAD@ @PAD@",
        "the meaning of life"
    ]
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = [BasicDecoderState([t], [t], se, se) for t in texts]
    merged_x = x[0].merge(x)
    texts = ["i went to chocolate", "awesome is", "the meaning of life"]
    batch_x = BasicDecoderState(texts, texts, se, se)
    print(merged_x.inp_tensor)
    print(batch_x.inp_tensor)
    self.assertTrue(torch.allclose(merged_x.inp_tensor, batch_x.inp_tensor))
    self.assertTrue(torch.allclose(merged_x.gold_tensor, batch_x.gold_tensor))
def test_tf_decoder_with_losses(self):
    texts = [
        "i went to chocolate @END@", "awesome is @END@",
        "the meaning of life @END@"
    ]
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    for t in texts:
        se.inc_build_vocab(t)
    se.finalize_vocab()
    x = BasicDecoderState(texts, texts, se, se)

    class Model(TransitionModel):
        def forward(self, x: BasicDecoderState):
            outprobs = torch.rand(len(x), x.query_encoder.vocab.number_of_ids())
            outprobs = torch.nn.functional.log_softmax(outprobs, -1)
            return outprobs, x

    celoss = CELoss(ignore_index=0)
    accs = SeqAccuracies()
    dec = SeqDecoder(TFTransition(Model()), eval=[celoss, accs])
    y = dec(x)
    print(y[0])
    print(y[1].followed_actions)
    print(y[1].get_gold())

    # print(y[1].followed_actions)
    outactions = y[1].followed_actions.detach().cpu().numpy()
    # print(outactions[0])
    # print(se.vocab.print(outactions[0]))
    # print(se.vocab.print(outactions[1]))
    # print(se.vocab.print(outactions[2]))
    self.assertTrue(se.vocab.print(outactions[0]) == texts[0])
    self.assertTrue(se.vocab.print(outactions[1]) == texts[1])
    self.assertTrue(se.vocab.print(outactions[2]) == texts[2])
class GeoDatasetRank(object):
    def __init__(self, p="geoquery_gen/run4/", min_freq: int = 2, splits=None, **kw):
        super(GeoDatasetRank, self).__init__(**kw)
        self._initialize(p)
        self.splits_proportions = splits

    def _initialize(self, p):
        self.data = {}
        with open(os.path.join(p, "trainpreds.json")) as f:
            trainpreds = ujson.load(f)
        with open(os.path.join(p, "testpreds.json")) as f:
            testpreds = ujson.load(f)
        splits = ["train"] * len(trainpreds) + ["test"] * len(testpreds)
        preds = trainpreds + testpreds

        self.sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
        self.query_encoder = SequenceEncoder(tokenizer=lambda x: x.split())

        # build vocabularies
        for i, (example, split) in enumerate(zip(preds, splits)):
            self.sentence_encoder.inc_build_vocab(" ".join(example["sentence"]), seen=split == "train")
            self.query_encoder.inc_build_vocab(" ".join(example["gold"]), seen=split == "train")
            for can in example["candidates"]:
                self.query_encoder.inc_build_vocab(" ".join(can["tokens"]), seen=False)
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab()
        self.query_encoder.finalize_vocab()

        self.build_data(preds, splits)

    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            if not isinstance(gold_tree, Tree):
                assert (gold_tree is not None)
            gold_tensor, gold_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")

            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]), None)
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert (cand_tree is not None)
                cand_tensor, cand_tokens = self.query_encoder.convert(" ".join(cand["tokens"]),
                                                                      return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(cand["alignments"]))
                candidate_align_entropies.append(torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                candidate_same.append(are_equal_trees(cand_tree, gold_tree,
                                                      orderless={"and", "or"},
                                                      unktoken="@NOUNKTOKENHERE@"))

            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0), 0)
            candidate_align_tensor = torch.stack(q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            state = RankState(
                inp_tensor[None, :],
                gold_tensor[None, :],
                candidate_tensor[None, :, :],
                candidate_same[None, :],
                candidate_align_tensor[None, :],
                candidate_align_entropy[None, :],
                self.sentence_encoder.vocab,
                self.query_encoder.vocab,
            )
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1), gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
            goldmaxlen = max(goldmaxlen, state.candtensors.size(-1))
        inp_tensors = q.pad_tensors([state.inp_tensor for state in data], 1, 0)
        gold_tensors = q.pad_tensors([state.gold_tensor for state in data], 1, 0)
        candtensors = q.pad_tensors([state.candtensors for state in data], 2, 0)
        alignments = q.pad_tensors([state.alignments for state in data], 2, 0)
        alignment_entropies = q.pad_tensors([state.alignment_entropies for state in data], 2, 0)
        for i, state in enumerate(data):
            state.inp_tensor = inp_tensors[i]
            state.gold_tensor = gold_tensors[i]
            state.candtensors = candtensors[i]
            state.alignments = alignments[i]
            state.alignment_entropies = alignment_entropies[i]
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            assert (split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
class LCQuaDnoENTDataset(object):
    def __init__(self, p="../../datasets/lcquad/",
                 sentence_encoder: SequenceEncoder = None, min_freq: int = 2, splits=None, **kw):
        super(LCQuaDnoENTDataset, self).__init__(**kw)
        self._simplify_filters = True   # if True, filter expressions are converted to orderless and-expressions
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def lines_to_examples(self, lines: List[str]):
        maxsize_before = 0
        avgsize_before = []
        maxsize_after = 0
        avgsize_after = []
        afterstring = set()

        def convert_to_lispstr(_x):
            splits = _x.split()
            assert (sum([1 if xe == "~" else 0 for xe in splits]) == 1)
            assert (splits[1] == "~")
            splits = ["," if xe == "&" else xe for xe in splits]
            pstr = f"{splits[0]} ({' '.join(splits[2:])})"
            return pstr

        ret = []
        ltp = None
        j = 0
        for i, line in enumerate(lines):
            question = line["question"]
            query = line["logical_form"]
            query = convert_to_lispstr(query)
            z, ltp = prolog_to_pas(query, ltp)
            if z is not None:
                ztree = pas_to_tree(z)
                maxsize_before = max(maxsize_before, tree_size(ztree))
                avgsize_before.append(tree_size(ztree))
                lf = ztree
                ret.append((question, lf))
                # print(f"Example {j}:")
                # print(ret[-1][0])
                # print(ret[-1][1])
                # print()
                ltp = None
                maxsize_after = max(maxsize_after, tree_size(lf))
                avgsize_after.append(tree_size(lf))
                j += 1

        avgsize_before = sum(avgsize_before) / len(avgsize_before)
        avgsize_after = sum(avgsize_after) / len(avgsize_after)

        print(f"Sizes ({j} examples):")
        # print(f"\t Max, Avg size before: {maxsize_before}, {avgsize_before}")
        print(f"\t Max, Avg size: {maxsize_after}, {avgsize_after}")

        return ret

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder

        jp = os.path.join(p, "lcquad_dataset.json")
        with open(jp, "r") as f:
            examples = ujson.load(f)
        examples = self.lines_to_examples(examples)

        questions, queries = tuple(zip(*examples))
        trainlen = int(round(0.8 * len(examples)))
        validlen = int(round(0.1 * len(examples)))
        testlen = int(round(0.1 * len(examples)))
        splits = ["train"] * trainlen + ["valid"] * validlen + ["test"] * testlen
        random.seed(1337)
        random.shuffle(splits)
        assert (len(splits) == len(examples))

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(tree_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        eid = 0

        gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
        rare_tokens = self.query_encoder.vocab.rare_tokens - set(self.sentence_encoder.vocab.D.keys())
        for rare_token in rare_tokens:
            gold_map[self.query_encoder.vocab[rare_token]] = \
                self.query_encoder.vocab[self.query_encoder.vocab.unktoken]

        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [out], inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab)
            state.eids = np.asarray([eid], dtype="int64")
            maxlen_in, maxlen_out = max(maxlen_in, len(state.inp_tokens[0])), \
                                    max(maxlen_out, len(state.gold_tokens[0]))
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            eid += 1
        self.maxlen_input, self.maxlen_output = maxlen_in, maxlen_out

    def get_split(self, split: str):
        splits = split.split("+")
        data = []
        for split in splits:
            data += self.data[split]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            dl = DataLoader(self.get_split(split), batch_size=batsize,
                            shuffle=split in ("train", "train+valid"),
                            collate_fn=type(self).collate_fn)
            return dl
def load_ds(traindomains=("restaurants",),
            testdomain="housing",
            min_freq=1,
            mincoverage=1,
            top_k=np.infty,
            nl_mode="bert-base-uncased",
            fullsimplify=False,
            onlyabstract=False,
            pretrainsetting="all+lex",   # "all", "lex" or "all+lex"
            finetunesetting="lex",       # "lex", "all", "min"
            ):
    """
    :param traindomains:
    :param testdomain:
    :param min_freq:
    :param mincoverage:
    :param top_k:
    :param nl_mode:
    :param fullsimplify:
    :param onlyabstract:
    :param pretrainsetting:     "all": use all examples from every domain
                                "lex": use only lexical examples
                                "all+lex": use both
    :param finetunesetting:     "lex": use lexical examples
                                "all": use all training examples
                                "min": use minimal lexicon-covering set of examples
            ! Test is always over the same original test set.
            ! Validation is over a fraction of training data
    :return:
    """
    general_tokens = {
        "(", ")", "arg:~type", "arg:type", "op:and", "SW:concat", "cond:has",
        "arg:<=", "arg:<", "arg:>=", "arg:>", "arg:!=", "arg:=", "SW:superlative",
        "SW:CNT-arg:min", "SW:CNT-arg:<", "SW:CNT-arg:<=", "SW:CNT-arg:>=",
        "SW:CNT-arg:>", "SW:CNT-arg:max", "SW:CNT-arg:=", "arg:max",
    }

    def tokenize_and_add_start(t):
        tokens = tree_to_lisp_tokens(t)
        starttok = "@START@"
        tokens = [starttok] + tokens
        return tokens

    sourceex = []
    for traindomain in traindomains:
        ds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full",
                                    simplify_blocks=True,
                                    restore_reverse=DATA_RESTORE_REVERSE,
                                    validfrac=.10)\
            .load(domain=traindomain)
        sourceex += ds[(None, None, lambda x: x in ("train", "valid", "lexicon"))]\
            .map(lambda x: (x[0], x[1], x[2], traindomain)).examples   # don't use test examples

    testds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full",
                                    simplify_blocks=True,
                                    restore_reverse=DATA_RESTORE_REVERSE)\
        .load(domain=testdomain)
    targetex = testds.map(lambda x: x + (testdomain,)).examples

    pretrainex = []
    if "all" in pretrainsetting.split("+"):
        pretrainex += [(a, tokenize_and_add_start(b), "pretrain", d)
                       for a, b, c, d in sourceex if c == "train"]
    if "lex" in pretrainsetting.split("+"):
        pretrainex += [(a, tokenize_and_add_start(b), "pretrain", d)
                       for a, b, c, d in sourceex if c == "lexicon"]

    pretrainvalidex = [(a, tokenize_and_add_start(b), "pretrainvalid", d)
                       for a, b, c, d in sourceex if c == "valid"]

    if finetunesetting == "all":
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d)
                           for a, b, c, d in targetex if c == "train"]
    elif finetunesetting == "lex":
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d)
                           for a, b, c, d in targetex if c == "lexicon"]
    elif finetunesetting == "min":
        finetunetrainex = get_maximum_spanning_examples(
            [(a, b, c, d) for a, b, c, d in targetex if c == "train"],
            mincoverage=mincoverage,
            loadedex=[e for e in pretrainex if e[2] == "pretrain"])
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d)
                           for a, b, c, d in finetunetrainex]

    finetunevalidex = [(a, tokenize_and_add_start(b), "ftvalid", d)
                       for a, b, c, d in targetex if c == "valid"]
    finetunetestex = [(a, tokenize_and_add_start(b), "fttest", d)
                      for a, b, c, d in targetex if c == "test"]
    print(f"Using mode \"{finetunesetting}\" for finetuning data: "
          f"\n\t{len(finetunetrainex)} training examples")

    allex = pretrainex + pretrainvalidex + finetunetrainex + finetunevalidex + finetunetestex
    ds = Dataset(allex)

    if onlyabstract:
        et = get_lf_abstract_transform(ds[lambda x: x[3] != testdomain].examples)
        ds = ds.map(lambda x: (x[0], et(x[1]), x[2], x[3]))

    seqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    seqenc = SequenceEncoder(vocab=seqenc_vocab,
                             tokenizer=lambda x: x,
                             add_start_token=False,
                             add_end_token=True)
    for example in ds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] in ("pretrain", "fttrain"))
    seqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    generaltokenmask = torch.zeros(seqenc_vocab.number_of_ids(), dtype=torch.long)
    for token, tokenid in seqenc_vocab.D.items():
        if token in general_tokens:
            generaltokenmask[tokenid] = 1

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)

    def tokenize(x):
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               seqenc.convert(x[1], return_what="tensor"),
               x[2], x[0], x[1], x[3])
        return ret

    tds, ftds, vds, fvds, xds = ds[(None, None, "pretrain", None)].map(tokenize), \
                                ds[(None, None, "fttrain", None)].map(tokenize), \
                                ds[(None, None, "pretrainvalid", None)].map(tokenize), \
                                ds[(None, None, "ftvalid", None)].map(tokenize), \
                                ds[(None, None, "fttest", None)].map(tokenize)
    return tds, ftds, vds, fvds, xds, nl_tokenizer, seqenc, generaltokenmask
def load_ds(domain="restaurants", nl_mode="bert-base-uncased", trainonvalid=False, noreorder=False):
    """
    Creates a dataset of examples which have
    * NL question and tensor
    * original FL tree
    * reduced FL tree with slots (this is randomly generated)
    * tensor corresponding to reduced FL tree with slots
    * mask specifying which elements in reduced FL tree are terminated
    * 2D gold that specifies whether a token/action is in gold for every position (compatibility with MML!)
    """
    orderless = {"op:and", "SW:concat"}     # only use in eval!!

    ds = OvernightDatasetLoader().load(domain=domain, trainonvalid=trainonvalid)
    ds = ds.map(lambda x: (x[0], ATree("@START@", [x[1]]), x[2]))

    if not noreorder:
        ds = ds.map(lambda x: (x[0], reorder_tree(x[1], orderless=orderless), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    vocab.add_token("@START@", seen=np.infty)
    vocab.add_token("@CLOSE@", seen=np.infty)          # only here for the action of closing an open position, will not be seen at input
    vocab.add_token("@OPEN@", seen=np.infty)           # only here for the action of opening a closed position, will not be seen at input
    vocab.add_token("@REMOVE@", seen=np.infty)         # only here for deletion operations, won't be seen at input
    vocab.add_token("@REMOVESUBTREE@", seen=np.infty)  # only here for deletion operations, won't be seen at input
    vocab.add_token("@SLOT@", seen=np.infty)           # will be seen at input, can't be produced!

    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)
    # for tok, idd in nl_tokenizer.vocab.items():
    #     vocab.add_token(tok, seen=np.infty)          # all wordpieces are added for possible later generation

    tds, vds, xds = ds[lambda x: x[2] == "train"], \
                    ds[lambda x: x[2] == "valid"], \
                    ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(vocab=vocab,
                             tokenizer=lambda x: extract_info(x, onlytokens=True),
                             add_start_token=False,
                             add_end_token=False)
    for example in tds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=True)
    for example in vds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    for example in xds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        nl = x[0]
        fl = x[1]
        fltoks = extract_info(fl, onlytokens=True)
        seq = seqenc.convert(fltoks, return_what="tensor")
        ret = (nl_tokenizer.encode(nl, return_tensors="pt")[0], seq)
        return ret

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless
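# Illustrative usage sketch (not part of the original file): loading the slot-based
# Overnight dataset defined above and checking the sizes of the returned splits;
# the domain name is only an example and the data is assumed to be available locally.
def _demo_load_ds_slots():
    tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless = load_ds(domain="restaurants")
    print(f"train={len(tds_seq)}, valid={len(vds_seq)}, test={len(xds_seq)}")
    print(f"orderless operators (eval only): {orderless}")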
class GeoDataset(object):
    def __init__(self, p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None, min_freq: int = 2, splits=None, **kw):
        super(GeoDataset, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]
        splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str], unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree], inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert (split in self.data.keys())
            dl = DataLoader(self.get_split(split), batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=type(self).collate_fn)
            return dl
class GeoDataset(object):
    def __init__(self, p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None, min_freq: int = 2,
                 cvfolds=None, testfold=None, reorder_random=False, **kw):
        super(GeoDataset, self).__init__(**kw)
        self.cvfolds, self.testfold = cvfolds, testfold
        self.reorder_random = reorder_random
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]
        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = ["valid" if x == self.testfold else "train" for x in splits]
            splits = splits + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.infty, -np.infty]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        token_specs["and"][1] = np.infty
        return token_specs

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str], unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids())
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree], inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab,
                                     token_specs=self.token_specs)

            if split == "train" and self.reorder_random is True:
                gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
                random_gold_tree = random.choice(get_tree_permutations(gold_tree_, orderless={"and"}))
                out_ = tree_to_lisp(random_gold_tree)
                out_tensor_, out_tokens_ = self.query_encoder.convert(out_, return_what="tensor,tokens")
                if gold_map is not None:
                    out_tensor_ = gold_map[out_tensor_]
                state.gold_tensor = out_tensor_[None]

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
class GeoDataset(object):
    def __init__(self, p="../../datasets/geo880_multiling/geoquery/",
                 train_lang="en", test_lang=None, bert_tokenizer=None,
                 min_freq: int = 2, cvfolds=None, testfold=None, **kw):
        super(GeoDataset, self).__init__(**kw)
        self.train_lang = train_lang
        self.test_lang = test_lang if test_lang is not None else train_lang
        self.cvfolds, self.testfold = cvfolds, testfold
        self._initialize(p, bert_tokenizer, min_freq)

    def _initialize(self, p, bert_tokenizer, min_freq: int):
        self.data = {}
        self.bert_vocab = Vocab()
        self.bert_vocab.set_dict(bert_tokenizer.vocab)
        self.sentence_encoder = SequenceEncoder(
            lambda x: bert_tokenizer.tokenize(f"[CLS] {x} [SEP]"),
            vocab=self.bert_vocab)
        trainlines = [x for x in ujson.load(open(os.path.join(p, f"geo-{self.train_lang}.json"), "r"))]
        testlines = [x for x in ujson.load(open(os.path.join(p, f"geo-{self.train_lang}.json"), "r"))]
        trainlines = [x for x in trainlines if x["split"] == "train"]
        testlines = [x for x in testlines if x["split"] == "test"]
        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = ["valid" if x == self.testfold else "train" for x in splits]
            splits = splits + ["test"] * len(testlines)
        questions = [x["nl"] for x in trainlines]
        queries = [x["mrl"] for x in trainlines]
        xquestions = [x["nl"] for x in testlines]
        xqueries = [x["mrl"] for x in testlines]
        questions += xquestions
        queries += xqueries

        # initialize output vocabulary
        outvocab = Vocab()
        for token, bertid in self.bert_vocab.D.items():
            outvocab.add_token(token, seen=False)

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(basic_query_tokenizer, strtok=bert_tokenizer),
            vocab=outvocab,
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        keeptokens = set(self.bert_vocab.D.keys())
        self.query_encoder.finalize_vocab(min_freq=min_freq, keep_tokens=keeptokens)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.infty, -np.infty]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)
        # token_specs["and"][1] = np.infty
        return token_specs

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for inp, out, split in zip(inputs, outputs, splits):
            # tokenize both input and output
            inp_tokens = self.sentence_encoder.convert(inp, return_what="tokens")[0]
            out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
            # get gold tree
            gold_tree = lisp_to_tree(" ".join(out_tokens[:-1]))
            assert (gold_tree is not None)
            # replace words in output that can't be copied from given input to UNK tokens
            unktoken = self.query_encoder.vocab.unktoken
            inp_tokens_ = set(inp_tokens)
            out_tokens = [
                out_token if out_token in inp_tokens_
                or (out_token in self.query_encoder.vocab
                    and out_token not in self.query_encoder.vocab.rare_tokens)
                else unktoken
                for out_token in out_tokens
            ]
            # convert token sequences to ids
            inp_tensor = self.sentence_encoder.convert(inp_tokens, return_what="tensor")[0]
            out_tensor = self.query_encoder.convert(out_tokens, return_what="tensor")[0]

            state = TreeDecoderState([inp], [gold_tree], inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab,
                                     token_specs=self.token_specs)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
def load_ds(traindomains=("restaurants",),
            testdomain="housing",
            min_freq=1,
            mincoverage=1,
            top_k=np.infty,
            nl_mode="bert-base-uncased",
            fullsimplify=False,
            add_domain_start=True,
            useall=False):
    def tokenize_and_add_start(t, _domain):
        tokens = tree_to_lisp_tokens(t)
        starttok = f"@START/{_domain}@" if add_domain_start else "@START@"
        tokens = [starttok] + tokens
        return tokens

    allex = []
    for traindomain in traindomains:
        ds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full",
                                    simplify_blocks=True,
                                    restore_reverse=DATA_RESTORE_REVERSE,
                                    validfrac=.10)\
            .load(domain=traindomain)
        allex += ds[(None, None, lambda x: x in ("train", "valid"))]\
            .map(lambda x: (x[0], x[1], x[2], traindomain)).examples   # don't use test examples

    testds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full",
                                    simplify_blocks=True,
                                    restore_reverse=DATA_RESTORE_REVERSE)\
        .load(domain=testdomain)
    if useall:
        print("using all training examples")
        sortedexamples = testds[(None, None, "train")].examples
    else:
        sortedexamples = get_maximum_spanning_examples(
            testds[(None, None, "train")].examples,
            mincoverage=mincoverage,
            loadedex=[e for e in allex if e[2] == "train"])

    allex += testds[(None, None, "valid")].map(lambda x: (x[0], x[1], "ftvalid", testdomain)).examples
    allex += testds[(None, None, "test")].map(lambda x: (x[0], x[1], x[2], testdomain)).examples
    allex += [(ex[0], ex[1], "fttrain", testdomain) for ex in sortedexamples]

    _ds = Dataset(allex)
    ds = _ds.map(lambda x: (x[0], tokenize_and_add_start(x[1], x[3]), x[2], x[3]))

    et = get_lf_abstract_transform(ds[lambda x: x[3] != testdomain].examples)
    ds = ds.map(lambda x: (x[0], et(x[1]), x[1], x[2], x[3]))

    seqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    absseqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    absseqenc = SequenceEncoder(vocab=seqenc_vocab,
                                tokenizer=lambda x: x,
                                add_start_token=False,
                                add_end_token=True)
    fullseqenc = SequenceEncoder(vocab=absseqenc_vocab,
                                 tokenizer=lambda x: x,
                                 add_start_token=False,
                                 add_end_token=True)
    for example in ds.examples:
        absseqenc.inc_build_vocab(example[1], seen=example[3] in ("train", "fttrain"))
        fullseqenc.inc_build_vocab(example[2], seen=example[3] in ("train", "fttrain"))
    absseqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)
    fullseqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)

    def tokenize(x):
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               absseqenc.convert(x[1], return_what="tensor"),
               fullseqenc.convert(x[2], return_what="tensor"),
               x[3], x[0], x[1], x[4])
        return ret

    tds, ftds, vds, fvds, xds = ds[(None, None, None, "train", None)].map(tokenize), \
                                ds[(None, None, None, "fttrain", None)].map(tokenize), \
                                ds[(None, None, None, "valid", None)].map(tokenize), \
                                ds[(None, None, None, "ftvalid", None)].map(tokenize), \
                                ds[(None, None, None, "test", None)].map(tokenize)
    return tds, ftds, vds, fvds, xds, nl_tokenizer, fullseqenc, absseqenc
def test_beam_search_stored_probs(self):
    with torch.no_grad():
        texts = ["a b"] * 2
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            transition_tensor = torch.tensor(
                [[0, 0, 0, 0.01, .51, .48],
                 [0, 0, 0, 0.01, .51, .48],
                 [0, 0, 0, 0.01, .51, .48],
                 [0, 0, 0, 0.01, .51, .48],
                 [0, 0, 0, 0.01, .51, .48],
                 [0, 0, 0, 0.01, .01, .98]])

            def forward(self, x: BasicDecoderState):
                prev = x.prev_actions
                outprobs = self.transition_tensor[prev]
                outprobs = torch.log(outprobs)
                outprobs -= 0.01 * torch.rand_like(outprobs)
                return outprobs, x

        model = Model()
        beamsize = 3
        maxtime = 10
        beam_xs = [x.make_copy(detach=False, deep=True) for _ in range(beamsize)]
        beam_states = BeamState(beam_xs)
        print(len(beam_xs))
        print(len(beam_states))

        bt = BeamTransition(model, beamsize, maxtime=maxtime)
        i = 0
        _, _, y, _ = bt(x, i)
        i += 1
        _, _, y, _ = bt(y, i)
        all_terminated = False
        while not all_terminated:
            start_time = time.time()
            _, _, y, all_terminated = bt(y, i)
            i += 1
            # print(i)
            end_time = time.time()
            print(f"{i}: {end_time - start_time}")

        print(y)
        print(y.bstates.get(0).followed_actions)

        best_actions = y.bstates.get(0).followed_actions
        best_actionprobs = y.actionprobs.get(0)
        for i in range(len(best_actions)):
            print(i)
            i_prob = 0
            for j in range(len(best_actions[i])):
                action_id = best_actions[i, j]
                action_prob = best_actionprobs[i, j, action_id]
                i_prob += action_prob
            print(i_prob)
            print(y.bscores[i, 0])
            self.assertTrue(torch.allclose(i_prob, y.bscores[i, 0]))
def load_ds(domain="restaurants", nl_mode="bert-base-uncased",
            trainonvalid=False, noreorder=False, numbered=False):
    """
    Creates a dataset of examples which have
    * NL question and tensor
    * original FL tree
    * reduced FL tree with slots (this is randomly generated)
    * tensor corresponding to reduced FL tree with slots
    * mask specifying which elements in reduced FL tree are terminated
    * 2D gold that specifies whether a token/action is in gold for every position (compatibility with MML!)
    """
    # orderless = {"op:and", "SW:concat"}     # only use in eval!!
    orderless = ORDERLESS

    ds = OvernightDatasetLoader(simplify_mode="none").load(domain=domain, trainonvalid=trainonvalid)
    # ds contains 3-tuples of (input, output tree, split name)

    if not noreorder:
        ds = ds.map(lambda x: (x[0], reorder_tree(x[1], orderless=orderless), x[2]))
    ds = ds.map(lambda x: (x[0], tree_to_seq(x[1]), x[2]))

    if numbered:
        ds = ds.map(lambda x: (x[0], make_numbered_tokens(x[1]), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    vocab.add_token("@BOS@", seen=np.infty)
    vocab.add_token("@EOS@", seen=np.infty)
    vocab.add_token("@STOP@", seen=np.infty)

    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)

    tds, vds, xds = ds[lambda x: x[2] == "train"], \
                    ds[lambda x: x[2] == "valid"], \
                    ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(vocab=vocab,
                             tokenizer=lambda x: x,
                             add_start_token=False,
                             add_end_token=False)
    for example in tds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=True)
    for example in vds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    for example in xds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        seq = seqenc.convert(x[1], return_what="tensor")
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0], seq)
        return ret

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless
class GeoQueryDataset(object):
    def __init__(self, p="../../datasets/geoquery/",
                 sentence_encoder: SequenceEncoder = None, min_freq: int = 2, **kw):
        super(GeoQueryDataset, self).__init__(**kw)
        self.data = {}
        self.sentence_encoder = sentence_encoder

        questions = [x.strip() for x in open(os.path.join(p, "questions.txt"), "r").readlines()]
        queries = [x.strip() for x in open(os.path.join(p, "queries.funql"), "r").readlines()]
        trainidxs = set([int(x.strip()) for x in open(os.path.join(p, "train_indexes.txt"), "r").readlines()])
        testidxs = set([int(x.strip()) for x in open(os.path.join(p, "test_indexes.txt"), "r").readlines()])
        splits = [None] * len(questions)
        for trainidx in trainidxs:
            splits[trainidx] = "train"
        for testidx in testidxs:
            splits[testidx] = "test"
        if any([split is None for split in splits]):
            print(f"{len([split for split in splits if split is None])} examples not assigned to any split")

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str]):
        for inp, out, split in zip(inputs, outputs, splits):
            state = BasicDecoderState([inp], [out], self.sentence_encoder, self.query_encoder)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert (split in self.data.keys())
            dl = DataLoader(self.get_split(split), batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=GeoQueryDataset.collate_fn)
            return dl
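# Illustrative usage sketch (not part of the original file): constructing the
# GeoQuery dataset with a simple whitespace sentence encoder and pulling one
# merged batch of decoder states; the dataset path is assumed to exist locally.
def _demo_geoquery_dataset():
    se = SequenceEncoder(tokenizer=lambda x: x.split())
    ds = GeoQueryDataset(sentence_encoder=se)
    batch = next(iter(ds.dataloader("train", batsize=4)))
    print(batch.inp_tensor.size(), batch.gold_tensor.size())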
class OvernightDataset(object):
    def __init__(self,
                 p="../../datasets/overnightData/",
                 pcache="../../datasets/overnightCache/",
                 domain: str = "restaurants",
                 sentence_encoder: SequenceEncoder = None,
                 usecache=True,
                 min_freq: int = 2,
                 **kw):
        super(OvernightDataset, self).__init__(**kw)
        self._simplify_filters = True  # if True, filter expressions are converted to orderless and-expressions
        self._pcache = pcache if usecache else None
        self._domain = domain
        self._usecache = usecache
        self._initialize(p, domain, sentence_encoder, min_freq)

    def lines_to_examples(self, lines: List[str]):
        maxsize_before = 0
        avgsize_before = []
        maxsize_after = 0
        avgsize_after = []
        afterstring = set()

        def simplify_tree(t: Tree):
            if t.label() == "call":
                assert (len(t[0]) == 0)
                # if not t[0].label().startswith("SW."):
                #     print(t)
                # assert(t[0].label().startswith("SW."))
                t.set_label(t[0].label())
                del t[0]
            elif t.label() == "string":
                afterstring.update(set([tc.label() for tc in t]))
                assert (len(t) == 1)
                assert (len(t[0]) == 0)
                t.set_label(f"arg:{t[0].label()}")
                del t[0]
            if t.label().startswith("edu.stanford.nlp.sempre.overnight.SimpleWorld."):
                t.set_label("SW:" + t.label()[len("edu.stanford.nlp.sempre.overnight.SimpleWorld."):])
            if t.label() == "SW:getProperty":
                assert (len(t) == 2)
                ret = simplify_tree(t[1])
                ret.append(simplify_tree(t[0]))
                return ret
            elif t.label() == "SW:singleton":
                assert (len(t) == 1)
                assert (len(t[0]) == 0)
                return t[0]
            elif t.label() == "SW:ensureNumericProperty":
                assert (len(t) == 1)
                return simplify_tree(t[0])
            elif t.label() == "SW:ensureNumericEntity":
                assert (len(t) == 1)
                return simplify_tree(t[0])
            elif t.label() == "SW:aggregate":
                assert (len(t) == 2)
                ret = simplify_tree(t[0])
                assert (ret.label() in ["arg:avg", "arg:sum"])
                assert (len(ret) == 0)
                ret.set_label(f"agg:{ret.label()}")
                ret.append(simplify_tree(t[1]))
                return ret
            else:
                t[:] = [simplify_tree(tc) for tc in t]
                return t

        def simplify_further(t):
            """ simplifies filters and count expressions """
            # replace filters with ands
            if t.label() == "SW:filter" and self._simplify_filters is True:
                if len(t) not in (2, 4):
                    raise Exception(f"filter expression should have 2 or 4 children, got {len(t)}")
                children = [simplify_further(tc) for tc in t]
                startset = children[0]
                if len(children) == 2:
                    condition = Tree("cond:has", [children[1]])
                elif len(children) == 4:
                    condition = Tree(f"cond:{children[2].label()}", [children[1], children[3]])
                conditions = [condition]
                if startset.label() == "op:and":
                    conditions = startset[:] + conditions
                else:
                    conditions = [startset] + conditions
                # check for same conditions:
                i = 0
                while i < len(conditions) - 1:
                    j = i + 1
                    while j < len(conditions):
                        if conditions[i] == conditions[j]:
                            print(f"SAME!: {conditions[i]}, {conditions[j]}")
                            del conditions[j]
                            j -= 1
                        j += 1
                    i += 1
                ret = Tree("op:and", conditions)
                return ret
            # replace countSuperlatives with specific ones
            elif t.label() == "SW:countSuperlative":
                assert (t[1].label() in ["arg:max", "arg:min"])
                t.set_label(f"SW:CNT-{t[1].label()}")
                del t[1]
                t[:] = [simplify_further(tc) for tc in t]
            elif t.label() == "SW:countComparative":
                assert (t[2].label() in ["arg:<", "arg:<=", "arg:>", "arg:>=", "arg:=", "arg:!="])
                t.set_label(f"SW:CNT-{t[2].label()}")
                del t[2]
                t[:] = [simplify_further(tc) for tc in t]
            else:
                t[:] = [simplify_further(tc) for tc in t]
            return t

        def simplify_furthermore(t):
            """ replace reverse rels """
            if t.label() == "arg:!type":
                t.set_label("arg:~type")
                return t
            elif t.label() == "SW:reverse":
                assert (len(t) == 1)
                assert (t[0].label().startswith("arg:"))
                assert (len(t[0]) == 0)
                t.set_label(f"arg:~{t[0].label()[4:]}")
                del t[0]
                return t
            elif t.label().startswith("cond:arg:"):
                assert (len(t) == 2)
                head = t[0]
                head = simplify_furthermore(head)
                assert (head.label().startswith("arg:"))
                assert (len(head) == 0)
                headlabel = f"arg:~{head.label()[4:]}"
                headlabel = headlabel.replace("~~", "")
                head.set_label(headlabel)
                body = simplify_furthermore(t[1])
                if t.label()[len("cond:arg:"):] != "=":
                    body = Tree(t.label()[5:], [body])
                head.append(body)
                return head
            else:
                t[:] = [simplify_furthermore(tc) for tc in t]
                return t

        def simplify_final(t):
            assert (t.label() == "SW:listValue")
            assert (len(t) == 1)
            return t[0]

        ret = []
        ltp = None
        j = 0
        for i, line in enumerate(lines):
            z, ltp = lisp_to_pas(line, ltp)
            if z is not None:
                print(f"Example {j}:")
                ztree = pas_to_tree(z[1][2][1][0])
                maxsize_before = max(maxsize_before, tree_size(ztree))
                avgsize_before.append(tree_size(ztree))
                lf = simplify_tree(ztree)
                lf = simplify_further(lf)
                lf = simplify_furthermore(lf)
                lf = simplify_final(lf)
                question = z[1][0][1][0]
                assert (question[0] == '"' and question[-1] == '"')
                ret.append((question[1:-1], lf))
                print(ret[-1][0])
                print(ret[-1][1])
                ltp = None
                maxsize_after = max(maxsize_after, tree_size(lf))
                avgsize_after.append(tree_size(lf))
                print(pas_to_tree(z[1][2][1][0]))
                print()
                j += 1

        avgsize_before = sum(avgsize_before) / len(avgsize_before)
        avgsize_after = sum(avgsize_after) / len(avgsize_after)

        print(f"Simplification results ({j} examples):")
        print(f"\t Max, Avg size before: {maxsize_before}, {avgsize_before}")
        print(f"\t Max, Avg size after: {maxsize_after}, {avgsize_after}")

        return ret

    def _load_cached(self):
        train_cached = ujson.load(
            open(os.path.join(self._pcache, f"{self._domain}.train.json"), "r"))
        trainexamples = [(x, Tree.fromstring(y)) for x, y in train_cached]
        test_cached = ujson.load(
            open(os.path.join(self._pcache, f"{self._domain}.test.json"), "r"))
        testexamples = [(x, Tree.fromstring(y)) for x, y in test_cached]
        print("loaded from cache")
        return trainexamples, testexamples

    def _cache(self, trainexamples: List[Tuple[str, Tree]],
               testexamples: List[Tuple[str, Tree]]):
        train_cached, test_cached = None, None
        if os.path.exists(os.path.join(self._pcache, f"{self._domain}.train.json")):
            try:
                train_cached = ujson.load(
                    open(os.path.join(self._pcache, f"{self._domain}.train.json"), "r"))
                test_cached = ujson.load(
                    open(os.path.join(self._pcache, f"{self._domain}.test.json"), "r"))
            except (IOError, ValueError) as e:
                pass
        trainexamples = [(x, str(y)) for x, y in trainexamples]
        testexamples = [(x, str(y)) for x, y in testexamples]

        if train_cached != trainexamples:
            with open(os.path.join(self._pcache, f"{self._domain}.train.json"), "w") as f:
                ujson.dump(trainexamples, f, indent=4, sort_keys=True)
        if test_cached != testexamples:
            with open(os.path.join(self._pcache, f"{self._domain}.test.json"), "w") as f:
                ujson.dump(testexamples, f, indent=4, sort_keys=True)
        print("saved in cache")

    def _initialize(self, p, domain, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder

        trainexamples, testexamples = None, None
        if self._usecache:
            try:
                trainexamples, testexamples = self._load_cached()
            except (IOError, ValueError) as e:
                pass

        if trainexamples is None:
            trainlines = [x.strip() for x in open(
                os.path.join(p, f"{domain}.paraphrases.train.examples"), "r").readlines()]
            testlines = [x.strip() for x in open(
                os.path.join(p, f"{domain}.paraphrases.test.examples"), "r").readlines()]
            trainexamples = self.lines_to_examples(trainlines)
            testexamples = self.lines_to_examples(testlines)
            if self._usecache:
                self._cache(trainexamples, testexamples)

        questions, queries = tuple(zip(*(trainexamples + testexamples)))
        trainlen = int(round(0.8 * len(trainexamples)))
        validlen = int(round(0.2 * len(trainexamples)))
        splits = ["train"] * trainlen + ["valid"] * validlen
        # random.seed(1223)
        random.shuffle(splits)
        assert (len(splits) == len(trainexamples))
        splits = splits + ["test"] * len(testexamples)

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(tree_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        eid = 0
        for inp, out, split in zip(inputs, outputs, splits):
            state = TreeDecoderState([inp], [out], self.sentence_encoder, self.query_encoder)
            state.eids = np.asarray([eid], dtype="int64")
            maxlen_in = max(maxlen_in, len(state.inp_tokens[0]))
            maxlen_out = max(maxlen_out, len(state.gold_tokens[0]))
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            eid += 1
        self.maxlen_input, self.maxlen_output = maxlen_in, maxlen_out

    def get_split(self, split: str):
        splits = split.split("+")
        data = []
        for split in splits:
            data += self.data[split]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split in ("train", "train+valid"),
                            collate_fn=OvernightDataset.collate_fn)
            return dl
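
# Usage sketch (illustrative, not part of the original pipeline): builds the
# Overnight "restaurants" domain and pulls one padded training batch. The
# whitespace tokenizer and the default data/cache paths are assumptions made
# for this example; the merged batch attributes (inp_tensor, gold_tensor)
# follow the fields used in collate_fn above.
def _example_overnight_usage():
    sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
    ds = OvernightDataset(domain="restaurants", sentence_encoder=sentence_encoder)
    dls = ds.dataloader(batsize=5)       # dict: one DataLoader per split
    batch = next(iter(dls["train"]))     # merged TreeDecoderState for one batch
    print(batch.inp_tensor.size(), batch.gold_tensor.size())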
class GeoQueryDatasetFunQL(object):
    def __init__(self,
                 p="../../datasets/geoquery/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 **kw):
        super(GeoQueryDatasetFunQL, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        questions = [x.strip() for x in
                     open(os.path.join(p, "questions.txt"), "r").readlines()]
        queries = [x.strip() for x in
                   open(os.path.join(p, "queries.funql"), "r").readlines()]
        trainidxs = set([int(x.strip()) for x in
                         open(os.path.join(p, "train_indexes.txt"), "r").readlines()])
        testidxs = set([int(x.strip()) for x in
                        open(os.path.join(p, "test_indexes.txt"), "r").readlines()])
        splits = [None] * len(questions)
        for trainidx in trainidxs:
            splits[trainidx] = "train"
        for testidx in testidxs:
            splits[testidx] = "test"
        if any([split is None for split in splits]):
            print(f"{len([split for split in splits if split is None])} examples not assigned to any split")

        self.query_encoder = SequenceEncoder(
            tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
            add_end_token=True)

        # build vocabularies
        unktokens = set()
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            question_tokens = self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            query_tokens = self.query_encoder.inc_build_vocab(query, seen=split == "train")
            unktokens |= set(query_tokens) - set(question_tokens)
        for word in self.sentence_encoder.vocab.counts.keys():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        unktokens = unktokens & self.query_encoder.vocab.rare_tokens

        self.build_data(questions, queries, splits, unktokens=unktokens)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str], unktokens: Set[str] = None):
        gold_map = None
        if unktokens is not None:
            gold_map = torch.arange(
                0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = prolog_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert (split in self.data.keys())
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=GeoQueryDatasetFunQL.collate_fn)
            return dl
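
# Usage sketch (illustrative, not part of the original pipeline): loads the
# GeoQuery FunQL data and inspects one training batch. The whitespace
# tokenizer and the default dataset path are assumptions for this example;
# query tokens that never occur in training questions are remapped to the
# unknown id by the gold_map built in build_data above.
def _example_geoquery_funql_usage():
    sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
    ds = GeoQueryDatasetFunQL(sentence_encoder=sentence_encoder)
    train_dl = ds.dataloader("train", batsize=5)
    batch = next(iter(train_dl))         # merged TreeDecoderState for one batch
    print(batch.gold_tensor[0])          # rare query tokens show up as the unk id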