class ConditionalRecallDataset(object):
    def __init__(self, maxlen=10, NperY=10, **kw):
        super(ConditionalRecallDataset, self).__init__(**kw)
        self.data = {}
        self.NperY, self.maxlen = NperY, maxlen
        self._seqs, self._ys = gen_data(self.maxlen, self.NperY)
        self.encoder = SequenceEncoder(tokenizer=lambda x: list(x))

        for seq, y in zip(self._seqs, self._ys):
            self.encoder.inc_build_vocab(seq)
            self.encoder.inc_build_vocab(y)

        self.N = len(self._seqs)
        N = self.N
        splits = ["train"] * int(N * 0.8) + ["valid"] * int(N * 0.1) + ["test"] * int(N * 0.1)
        random.shuffle(splits)
        self.encoder.finalize_vocab()
        self.build_data(self._seqs, self._ys, splits)

    def build_data(self, seqs, ys, splits):
        for seq, y, split in zip(seqs, ys, splits):
            seq_tensor = self.encoder.convert(seq, return_what="tensor")
            y_tensor = self.encoder.convert(y, return_what="tensor")
            if split not in self.data:
                self.data[split] = []
            self.data[split].append((seq_tensor[0], y_tensor[0][0]))

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            assert split in self.data.keys()
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle)
            return dl
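# Usage sketch (added for illustration, not part of the original code): builds the toy
# conditional-recall dataset and inspects one training batch. It assumes `gen_data`,
# `SequenceEncoder`, `DatasetSplitProxy` and `DataLoader` behave as used above, and that
# default collation can stack the generated sequences; parameter values are arbitrary.
def _demo_conditional_recall():
    ds = ConditionalRecallDataset(maxlen=10, NperY=10)
    dls = ds.dataloader(batsize=32)      # dict of DataLoaders keyed by split name
    seqs, ys = next(iter(dls["train"]))  # batch of sequence-id tensors and target ids
    print(seqs.shape, ys.shape)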
def __init__(self,
             inp_strings: List[str] = None,
             gold_strings: List[str] = None,
             inp_tensor: torch.Tensor = None,
             gold_tensor: torch.Tensor = None,
             inp_tokens: List[List[str]] = None,
             gold_tokens: List[List[str]] = None,
             sentence_encoder: SequenceEncoder = None,
             query_encoder: SequenceEncoder = None,
             **kw):
    if inp_strings is None:
        super(BasicDecoderState, self).__init__(**kw)
    else:
        kw = kw.copy()
        kw.update({"inp_strings": np.asarray(inp_strings),
                   "gold_strings": np.asarray(gold_strings)})
        super(BasicDecoderState, self).__init__(**kw)

        self.sentence_encoder = sentence_encoder
        self.query_encoder = query_encoder

        # self.set(followed_actions_str = np.asarray([None for _ in self.inp_strings]))
        # for i in range(len(self.followed_actions_str)):
        #     self.followed_actions_str[i] = []
        self.set(followed_actions=torch.zeros(len(inp_strings), 0, dtype=torch.long))
        self.set(_is_terminated=np.asarray([False for _ in self.inp_strings]))
        self.set(_timesteps=np.asarray([0 for _ in self.inp_strings]))

        if sentence_encoder is not None:
            x = [sentence_encoder.convert(x, return_what="tensor,tokens") for x in self.inp_strings]
            x = list(zip(*x))
            inp_tokens = np.asarray([None for _ in range(len(x[1]))], dtype=object)
            for i, inp_tokens_e in enumerate(x[1]):
                inp_tokens[i] = tuple(inp_tokens_e)
            x = {"inp_tensor": batchstack(x[0]), "inp_tokens": inp_tokens}
            self.set(**x)
        if self.gold_strings is not None:
            if query_encoder is not None:
                x = [query_encoder.convert(x, return_what="tensor,tokens") for x in self.gold_strings]
                x = list(zip(*x))
                gold_tokens = np.asarray([None for _ in range(len(x[1]))], dtype=object)
                for i, gold_tokens_e in enumerate(x[1]):
                    gold_tokens[i] = tuple(gold_tokens_e)
                x = {"gold_tensor": batchstack(x[0]), "gold_tokens": gold_tokens}
                self.set(**x)
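# Usage sketch (added for illustration): constructing a BasicDecoderState directly from raw
# strings and letting the encoders produce the batched tensors and token tuples. The example
# strings are hypothetical; the encoders are assumed to be SequenceEncoders with finalized
# vocabularies, as used elsewhere in this codebase.
def _demo_basic_decoder_state(sentence_encoder: SequenceEncoder, query_encoder: SequenceEncoder):
    state = BasicDecoderState(inp_strings=["what states border texas"],
                              gold_strings=["( answer ( state ( next_to texas ) ) )"],
                              sentence_encoder=sentence_encoder,
                              query_encoder=query_encoder)
    print(state.inp_tensor.shape, state.gold_tensor.shape)
    print(state.inp_tokens[0], state.gold_tokens[0])
    return state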
def try_perturbed_generated_dataset():
    torch.manual_seed(1234)
    ovd = OvernightDatasetLoader().load()
    govd = PCFGDataset(OvernightPCFGBuilder()
                       .build(ovd[(None, None, lambda x: x in {"train", "valid"})]
                              .map(lambda f: f[1]).examples),
                       N=10000)
    print(govd[0])
    # print(govd[lambda x: True][0])
    # print(govd[:])

    # create vocab from pcfg
    vocab = build_vocab_from_pcfg(govd._pcfg)
    seqenc = SequenceEncoder(vocab=vocab, tokenizer=tree_to_lisp_tokens)
    spanmasker = SpanMasker(seed=12345667)
    treemasker = SubtreeMasker(p=.05, seed=2345677)

    perturbed_govd = govd.cache() \
        .map(lambda x: (seqenc.convert(x, "tensor"), x)) \
        .map(lambda x: x + (seqenc.convert(x[-1], "tokens"),)) \
        .map(lambda x: x + (spanmasker(x[-1]),)) \
        .map(lambda x: x + (seqenc.convert(x[-1], "tensor"),)) \
        .map(lambda x: (x[-1], x[0]))

    dl = DataLoader(perturbed_govd, batch_size=10, shuffle=True, collate_fn=pad_and_default_collate)
    batch = next(iter(dl))
    print(batch)
    print(vocab.tostr(batch[0][1]))
    print(vocab.tostr(batch[1][1]))

    tt = q.ticktock()
    tt.tick("first run")
    for i in range(10000):
        y = perturbed_govd[i]
        if i < 10:
            print(f"{y[0]}\n{y[-2]}")
    tt.tock("first run done")

    tt.tick("second run")
    for i in range(10000):
        y = perturbed_govd[i]
        if i < 10:
            print(f"{y[0]}\n{y[-2]}")
    tt.tock("second run done")
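# Simplified illustration (added): the pipeline above accumulates per-example tuples through
# successive .map() calls -- (tensor, tree), then +tokens, +masked tokens, +masked tensor,
# and finally (masked tensor, original tensor). The toy helpers below only mimic that shape;
# they are not the real SequenceEncoder or SpanMasker.
def _demo_map_chain():
    def to_tensor(tokens):            # stand-in for seqenc.convert(..., "tensor")
        return [hash(t) % 100 for t in tokens]

    def mask_span(tokens):            # stand-in for SpanMasker: mask one middle token
        return tokens[:1] + ["@MASK@"] + tokens[2:] if len(tokens) > 2 else tokens

    example = ["(", "call", "listValue", ")"]
    x = (to_tensor(example), example)   # (tensor, tree)
    x = x + (x[-1],)                    # + token sequence (already tokens in this toy case)
    x = x + (mask_span(x[-1]),)         # + masked tokens
    x = x + (to_tensor(x[-1]),)         # + masked tensor
    x = (x[-1], x[0])                   # (masked tensor, original tensor)
    return x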
class GeoDatasetRank(object):
    def __init__(self,
                 p="geoquery_gen/run4/",
                 min_freq: int = 2,
                 splits=None, **kw):
        super(GeoDatasetRank, self).__init__(**kw)
        self._initialize(p)
        self.splits_proportions = splits

    def _initialize(self, p):
        self.data = {}
        with open(os.path.join(p, "trainpreds.json")) as f:
            trainpreds = ujson.load(f)
        with open(os.path.join(p, "testpreds.json")) as f:
            testpreds = ujson.load(f)
        splits = ["train"] * len(trainpreds) + ["test"] * len(testpreds)
        preds = trainpreds + testpreds

        self.sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
        self.query_encoder = SequenceEncoder(tokenizer=lambda x: x.split())

        # build vocabularies
        for i, (example, split) in enumerate(zip(preds, splits)):
            self.sentence_encoder.inc_build_vocab(" ".join(example["sentence"]), seen=split == "train")
            self.query_encoder.inc_build_vocab(" ".join(example["gold"]), seen=split == "train")
            for can in example["candidates"]:
                self.query_encoder.inc_build_vocab(" ".join(can["tokens"]), seen=False)
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab()
        self.query_encoder.finalize_vocab()

        self.build_data(preds, splits)

    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])

            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            if not isinstance(gold_tree, Tree):
                assert gold_tree is not None
            gold_tensor, gold_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")

            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]), None)
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert cand_tree is not None
                cand_tensor, cand_tokens = self.query_encoder.convert(" ".join(cand["tokens"]),
                                                                      return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(cand["alignments"]))
                candidate_align_entropies.append(torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                candidate_same.append(are_equal_trees(cand_tree, gold_tree,
                                                      orderless={"and", "or"},
                                                      unktoken="@NOUNKTOKENHERE@"))

            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0), 0)
            candidate_align_tensor = torch.stack(q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            state = RankState(inp_tensor[None, :],
                              gold_tensor[None, :],
                              candidate_tensor[None, :, :],
                              candidate_same[None, :],
                              candidate_align_tensor[None, :],
                              candidate_align_entropy[None, :],
                              self.sentence_encoder.vocab,
                              self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1), gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
            goldmaxlen = max(goldmaxlen, state.candtensors.size(-1))
        inp_tensors = q.pad_tensors([state.inp_tensor for state in data], 1, 0)
        gold_tensors = q.pad_tensors([state.gold_tensor for state in data], 1, 0)
        candtensors = q.pad_tensors([state.candtensors for state in data], 2, 0)
        alignments = q.pad_tensors([state.alignments for state in data], 2, 0)
        alignment_entropies = q.pad_tensors([state.alignment_entropies for state in data], 2, 0)
        for i, state in enumerate(data):
            state.inp_tensor = inp_tensors[i]
            state.gold_tensor = gold_tensors[i]
            state.candtensors = candtensors[i]
            state.alignments = alignments[i]
            state.alignment_entropies = alignment_entropies[i]
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            assert split in self.data.keys()
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
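# Usage sketch (added for illustration): loading reranking predictions and batching RankStates.
# The run directory is the default used above and is assumed to contain trainpreds.json and
# testpreds.json in the format consumed by _initialize().
def _demo_geo_rank():
    ds = GeoDatasetRank(p="geoquery_gen/run4/")
    traindl = ds.dataloader("train", batsize=8)
    batch = next(iter(traindl))   # merged RankState with padded tensors
    print(batch.inp_tensor.shape, batch.candtensors.shape, batch.alignments.shape)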
class LCQuaDnoENTDataset(object):
    def __init__(self,
                 p="../../datasets/lcquad/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 splits=None, **kw):
        super(LCQuaDnoENTDataset, self).__init__(**kw)
        self._simplify_filters = True   # if True, filter expressions are converted to orderless and-expressions
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def lines_to_examples(self, lines: List[str]):
        maxsize_before = 0
        avgsize_before = []
        maxsize_after = 0
        avgsize_after = []
        afterstring = set()

        def convert_to_lispstr(_x):
            splits = _x.split()
            assert sum([1 if xe == "~" else 0 for xe in splits]) == 1
            assert splits[1] == "~"
            splits = ["," if xe == "&" else xe for xe in splits]
            pstr = f"{splits[0]} ({' '.join(splits[2:])})"
            return pstr

        ret = []
        ltp = None
        j = 0
        for i, line in enumerate(lines):
            question = line["question"]
            query = line["logical_form"]
            query = convert_to_lispstr(query)
            z, ltp = prolog_to_pas(query, ltp)
            if z is not None:
                ztree = pas_to_tree(z)
                maxsize_before = max(maxsize_before, tree_size(ztree))
                avgsize_before.append(tree_size(ztree))
                lf = ztree
                ret.append((question, lf))
                # print(f"Example {j}:")
                # print(ret[-1][0])
                # print(ret[-1][1])
                # print()
                ltp = None
                maxsize_after = max(maxsize_after, tree_size(lf))
                avgsize_after.append(tree_size(lf))
                j += 1

        avgsize_before = sum(avgsize_before) / len(avgsize_before)
        avgsize_after = sum(avgsize_after) / len(avgsize_after)

        print(f"Sizes ({j} examples):")
        # print(f"\t Max, Avg size before: {maxsize_before}, {avgsize_before}")
        print(f"\t Max, Avg size: {maxsize_after}, {avgsize_after}")

        return ret

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder

        jp = os.path.join(p, "lcquad_dataset.json")
        with open(jp, "r") as f:
            examples = ujson.load(f)
        examples = self.lines_to_examples(examples)

        questions, queries = tuple(zip(*examples))
        trainlen = int(round(0.8 * len(examples)))
        validlen = int(round(0.1 * len(examples)))
        testlen = int(round(0.1 * len(examples)))
        splits = ["train"] * trainlen + ["valid"] * validlen + ["test"] * testlen
        random.seed(1337)
        random.shuffle(splits)
        assert len(splits) == len(examples)

        self.query_encoder = SequenceEncoder(tokenizer=partial(tree_query_tokenizer,
                                                               strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        eid = 0

        gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
        rare_tokens = self.query_encoder.vocab.rare_tokens - set(self.sentence_encoder.vocab.D.keys())
        for rare_token in rare_tokens:
            gold_map[self.query_encoder.vocab[rare_token]] = \
                self.query_encoder.vocab[self.query_encoder.vocab.unktoken]

        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [out],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab)
            state.eids = np.asarray([eid], dtype="int64")
            maxlen_in = max(maxlen_in, len(state.inp_tokens[0]))
            maxlen_out = max(maxlen_out, len(state.gold_tokens[0]))
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            eid += 1

        self.maxlen_input, self.maxlen_output = maxlen_in, maxlen_out

    def get_split(self, split: str):
        splits = split.split("+")
        data = []
        for split in splits:
            data += self.data[split]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            dl = DataLoader(self.get_split(split), batch_size=batsize,
                            shuffle=split in ("train", "train+valid"),
                            collate_fn=type(self).collate_fn)
            return dl
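# Usage sketch (added for illustration): building the LC-QuAD dataset with a plain whitespace
# sentence encoder. The tokenizer choice is an assumption; any SequenceEncoder whose tokenizer
# is compatible with tree_query_tokenizer's strtok argument should work.
def _demo_lcquad():
    senc = SequenceEncoder(tokenizer=lambda x: x.split())
    ds = LCQuaDnoENTDataset(p="../../datasets/lcquad/", sentence_encoder=senc, min_freq=2)
    dls = ds.dataloader(batsize=16)      # {"train": ..., "valid": ..., "test": ...}
    batch = next(iter(dls["train"]))     # merged TreeDecoderState
    print(batch.inp_tensor.shape, batch.gold_tensor.shape)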
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 splits=None, **kw):
        super(GeoDataset, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]
        splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(tokenizer=partial(basic_query_tokenizer,
                                                               strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert gold_tree is not None
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert split in self.data.keys()
            dl = DataLoader(self.get_split(split), batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=type(self).collate_fn)
            return dl
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 cvfolds=None, testfold=None,
                 reorder_random=False, **kw):
        super(GeoDataset, self).__init__(**kw)
        self.cvfolds, self.testfold = cvfolds, testfold
        self.reorder_random = reorder_random
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]

        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = ["valid" if x == self.testfold else "train" for x in splits]
            splits = splits + ["test"] * len(testlines)

        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(tokenizer=partial(basic_query_tokenizer,
                                                               strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.infty, -np.infty]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
            assert out_tokens[-1] == "@END@"
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        token_specs["and"][1] = np.infty
        return token_specs

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids())
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]

        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert gold_tree is not None
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab,
                                     token_specs=self.token_specs)

            if split == "train" and self.reorder_random is True:
                gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
                random_gold_tree = random.choice(get_tree_permutations(gold_tree_, orderless={"and"}))
                out_ = tree_to_lisp(random_gold_tree)
                out_tensor_, out_tokens_ = self.query_encoder.convert(out_, return_what="tensor,tokens")
                if gold_map is not None:
                    out_tensor_ = gold_map[out_tensor_]
                state.gold_tensor = out_tensor_[None]

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
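# Usage sketch (added for illustration): with cvfolds/testfold set, the original training file
# is partitioned into folds and the selected fold becomes the "valid" split, so "train" and
# "valid" together cover the original training data. Fold values below are arbitrary.
def _demo_geo_cv():
    senc = SequenceEncoder(tokenizer=lambda x: x.split())
    ds = GeoDataset(p="../../datasets/geo880dong/", sentence_encoder=senc,
                    cvfolds=5, testfold=0)
    print({k: len(v) for k, v in ds.data.items()})   # sizes of train/valid/test
    dls = ds.dataloader(batsize=16)
    print(next(iter(dls["valid"])).inp_tensor.shape)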
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880_multiling/geoquery/",
                 train_lang="en",
                 test_lang=None,
                 bert_tokenizer=None,
                 min_freq: int = 2,
                 cvfolds=None, testfold=None, **kw):
        super(GeoDataset, self).__init__(**kw)
        self.train_lang = train_lang
        self.test_lang = test_lang if test_lang is not None else train_lang
        self.cvfolds, self.testfold = cvfolds, testfold
        self._initialize(p, bert_tokenizer, min_freq)

    def _initialize(self, p, bert_tokenizer, min_freq: int):
        self.data = {}
        self.bert_vocab = Vocab()
        self.bert_vocab.set_dict(bert_tokenizer.vocab)
        self.sentence_encoder = SequenceEncoder(lambda x: bert_tokenizer.tokenize(f"[CLS] {x} [SEP]"),
                                                vocab=self.bert_vocab)
        trainlines = [x for x in ujson.load(open(os.path.join(p, f"geo-{self.train_lang}.json"), "r"))]
        # test examples come from the test-language file
        testlines = [x for x in ujson.load(open(os.path.join(p, f"geo-{self.test_lang}.json"), "r"))]
        trainlines = [x for x in trainlines if x["split"] == "train"]
        testlines = [x for x in testlines if x["split"] == "test"]

        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = ["valid" if x == self.testfold else "train" for x in splits]
            splits = splits + ["test"] * len(testlines)

        questions = [x["nl"] for x in trainlines]
        queries = [x["mrl"] for x in trainlines]
        xquestions = [x["nl"] for x in testlines]
        xqueries = [x["mrl"] for x in testlines]
        questions += xquestions
        queries += xqueries

        # initialize output vocabulary
        outvocab = Vocab()
        for token, bertid in self.bert_vocab.D.items():
            outvocab.add_token(token, seen=False)

        self.query_encoder = SequenceEncoder(tokenizer=partial(basic_query_tokenizer, strtok=bert_tokenizer),
                                             vocab=outvocab, add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        keeptokens = set(self.bert_vocab.D.keys())
        self.query_encoder.finalize_vocab(min_freq=min_freq, keep_tokens=keeptokens)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.infty, -np.infty]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
            assert out_tokens[-1] == "@END@"
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        # token_specs["and"][1] = np.infty
        return token_specs

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for inp, out, split in zip(inputs, outputs, splits):
            # tokenize both input and output
            inp_tokens = self.sentence_encoder.convert(inp, return_what="tokens")[0]
            out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
            # get gold tree
            gold_tree = lisp_to_tree(" ".join(out_tokens[:-1]))
            assert gold_tree is not None
            # replace words in output that can't be copied from the given input with UNK tokens
            unktoken = self.query_encoder.vocab.unktoken
            inp_tokens_ = set(inp_tokens)
            out_tokens = [out_token if out_token in inp_tokens_
                          or (out_token in self.query_encoder.vocab
                              and out_token not in self.query_encoder.vocab.rare_tokens)
                          else unktoken
                          for out_token in out_tokens]
            # convert token sequences to ids
            inp_tensor = self.sentence_encoder.convert(inp_tokens, return_what="tensor")[0]
            out_tensor = self.query_encoder.convert(out_tokens, return_what="tensor")[0]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab,
                                     token_specs=self.token_specs)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split, shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in ("train", "train+valid")
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
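# Usage sketch (added for illustration): this variant expects a HuggingFace-style (slow)
# BertTokenizer exposing .vocab and .tokenize(); the multilingual checkpoint name and the
# German test language below are assumptions.
def _demo_geo_multilingual():
    from transformers import BertTokenizer
    tok = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
    ds = GeoDataset(p="../../datasets/geo880_multiling/geoquery/",
                    train_lang="en", test_lang="de", bert_tokenizer=tok)
    batch = next(iter(ds.dataloader("train", batsize=16)))
    print(batch.inp_tensor.shape, batch.gold_tensor.shape)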
class GeoQueryDatasetFunQL(object):
    def __init__(self,
                 p="../../datasets/geoquery/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2, **kw):
        super(GeoQueryDatasetFunQL, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        questions = [x.strip() for x in open(os.path.join(p, "questions.txt"), "r").readlines()]
        queries = [x.strip() for x in open(os.path.join(p, "queries.funql"), "r").readlines()]
        trainidxs = set([int(x.strip()) for x in open(os.path.join(p, "train_indexes.txt"), "r").readlines()])
        testidxs = set([int(x.strip()) for x in open(os.path.join(p, "test_indexes.txt"), "r").readlines()])
        splits = [None] * len(questions)
        for trainidx in trainidxs:
            splits[trainidx] = "train"
        for testidx in testidxs:
            splits[testidx] = "test"
        if any([split is None for split in splits]):
            print(f"{len([split for split in splits if split is None])} examples not assigned to any split")

        self.query_encoder = SequenceEncoder(tokenizer=partial(basic_query_tokenizer,
                                                               strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        unktokens = set()
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            question_tokens = self.sentence_encoder.inc_build_vocab(question, seen=split == "train")
            query_tokens = self.query_encoder.inc_build_vocab(query, seen=split == "train")
            unktokens |= set(query_tokens) - set(question_tokens)
        for word in self.sentence_encoder.vocab.counts.keys():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        unktokens = unktokens & self.query_encoder.vocab.rare_tokens

        self.build_data(questions, queries, splits, unktokens=unktokens)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None   # stays None when no output tokens need to be mapped to UNK
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = prolog_to_tree(out)
            assert gold_tree is not None
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert split in self.data.keys()
            dl = DataLoader(self.get_split(split), batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=GeoQueryDatasetFunQL.collate_fn)
            return dl
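# Usage sketch (added for illustration): the FunQL variant reads questions.txt / queries.funql
# plus the train/test index files from the dataset directory used above. The whitespace
# sentence tokenizer is again a stand-in.
def _demo_geo_funql():
    senc = SequenceEncoder(tokenizer=lambda x: x.split())
    ds = GeoQueryDatasetFunQL(p="../../datasets/geoquery/", sentence_encoder=senc, min_freq=2)
    traindl = ds.dataloader("train", batsize=16)
    batch = next(iter(traindl))
    print(batch.inp_tensor.shape, batch.gold_tensor.shape)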