Example #1
class ConditionalRecallDataset(object):
    def __init__(self, maxlen=10, NperY=10, **kw):
        super(ConditionalRecallDataset, self).__init__(**kw)
        self.data = {}
        self.NperY, self.maxlen = NperY, maxlen
        self._seqs, self._ys = gen_data(self.maxlen, self.NperY)
        self.encoder = SequenceEncoder(tokenizer=lambda x: list(x))

        for seq, y in zip(self._seqs, self._ys):
            self.encoder.inc_build_vocab(seq)
            self.encoder.inc_build_vocab(y)

        self.N = len(self._seqs)
        N = self.N

        splits = ["train"] * int(N * 0.8) + ["valid"] * int(
            N * 0.1) + ["test"] * int(N * 0.1)
        random.shuffle(splits)

        self.encoder.finalize_vocab()
        self.build_data(self._seqs, self._ys, splits)

    def build_data(self, seqs, ys, splits):
        for seq, y, split in zip(seqs, ys, splits):
            seq_tensor = self.encoder.convert(seq, return_what="tensor")
            y_tensor = self.encoder.convert(y, return_what="tensor")
            if split not in self.data:
                self.data[split] = []
            self.data[split].append((seq_tensor[0], y_tensor[0][0]))

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            assert (split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle)
            return dl
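A minimal usage sketch (not part of the original snippet), assuming gen_data, SequenceEncoder, DatasetSplitProxy and DataLoader are importable from the surrounding project:

ds = ConditionalRecallDataset(maxlen=10, NperY=10)
loaders = ds.dataloader(batsize=16)            # dict with "train", "valid" and "test" DataLoaders
train_dl = ds.dataloader("train", batsize=16)  # single DataLoader, shuffled because split == "train"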
Example #2
    def __init__(self,
                 inp_strings:List[str]=None,
                 gold_strings:List[str]=None,
                 inp_tensor:torch.Tensor=None,
                 gold_tensor:torch.Tensor=None,
                 inp_tokens:List[List[str]]=None,
                 gold_tokens:List[List[str]]=None,
                 sentence_encoder:SequenceEncoder=None,
                 query_encoder:SequenceEncoder=None,
                 **kw):
        if inp_strings is None:
            super(BasicDecoderState, self).__init__(**kw)
        else:
            kw = kw.copy()
            kw.update({"inp_strings": np.asarray(inp_strings), "gold_strings": np.asarray(gold_strings)})
            super(BasicDecoderState, self).__init__(**kw)

            self.sentence_encoder = sentence_encoder
            self.query_encoder = query_encoder

            # self.set(followed_actions_str = np.asarray([None for _ in self.inp_strings]))
            # for i in range(len(self.followed_actions_str)):
            #     self.followed_actions_str[i] = []
            self.set(followed_actions = torch.zeros(len(inp_strings), 0, dtype=torch.long))
            self.set(_is_terminated = np.asarray([False for _ in self.inp_strings]))
            self.set(_timesteps = np.asarray([0 for _ in self.inp_strings]))

            if sentence_encoder is not None:
                x = [sentence_encoder.convert(x, return_what="tensor,tokens") for x in self.inp_strings]
                x = list(zip(*x))
                inp_tokens = np.asarray([None for _ in range(len(x[1]))], dtype=object)
                for i, inp_tokens_e in enumerate(x[1]):
                    inp_tokens[i] = tuple(inp_tokens_e)
                x = {"inp_tensor": batchstack(x[0]),
                     "inp_tokens": inp_tokens}
                self.set(**x)
            if self.gold_strings is not None:
                if query_encoder is not None:
                    x = [query_encoder.convert(x, return_what="tensor,tokens") for x in self.gold_strings]
                    x = list(zip(*x))
                    gold_tokens = np.asarray([None for _ in range(len(x[1]))])
                    for i, gold_tokens_e in enumerate(x[1]):
                        gold_tokens[i] = tuple(gold_tokens_e)
                    x = {"gold_tensor": batchstack(x[0]),
                         "gold_tokens": gold_tokens}
                    self.set(**x)
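A hypothetical construction sketch: sentence_enc and query_enc are assumed to be SequenceEncoder instances whose vocabularies have already been built and finalized, so the constructor can encode the raw strings into tensors and token tuples.

sentence_enc = SequenceEncoder(tokenizer=lambda x: x.split())
query_enc = SequenceEncoder(tokenizer=lambda x: x.split())
# ... inc_build_vocab(...) / finalize_vocab() on both encoders goes here ...
state = BasicDecoderState(inp_strings=["what states border texas"],
                          gold_strings=["( answer ( state ( borders texas ) ) )"],
                          sentence_encoder=sentence_enc,
                          query_encoder=query_enc)
print(state.inp_tensor.shape, state.gold_tensor.shape)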
Example #3
def try_perturbed_generated_dataset():
    torch.manual_seed(1234)
    ovd = OvernightDatasetLoader().load()
    govd = PCFGDataset(OvernightPCFGBuilder()
                       .build(ovd[(None, None, lambda x: x in {"train", "valid"})]
                              .map(lambda f: f[1]).examples),
                       N=10000)

    print(govd[0])
    # print(govd[lambda x: True][0])

    # print(govd[:])
    # create vocab from pcfg
    vocab = build_vocab_from_pcfg(govd._pcfg)
    seqenc = SequenceEncoder(vocab=vocab, tokenizer=tree_to_lisp_tokens)
    spanmasker = SpanMasker(seed=12345667)
    treemasker = SubtreeMasker(p=.05, seed=2345677)

    perturbed_govd = govd.cache()\
        .map(lambda x: (seqenc.convert(x, "tensor"), x)) \
        .map(lambda x: x + (seqenc.convert(x[-1], "tokens"),)) \
        .map(lambda x: x + (spanmasker(x[-1]),)) \
        .map(lambda x: x + (seqenc.convert(x[-1], "tensor"),)) \
        .map(lambda x: (x[-1], x[0]))

    dl = DataLoader(perturbed_govd, batch_size=10, shuffle=True, collate_fn=pad_and_default_collate)
    batch = next(iter(dl))
    print(batch)
    print(vocab.tostr(batch[0][1]))
    print(vocab.tostr(batch[1][1]))

    tt = q.ticktock()
    tt.tick("first run")
    for i in range(10000):
        y = perturbed_govd[i]
        if i < 10:
            print(f"{y[0]}\n{y[-2]}")
    tt.tock("first run done")
    tt.tick("second run")
    for i in range(10000):
        y = perturbed_govd[i]
        if i < 10:
            print(f"{y[0]}\n{y[-2]}")
    tt.tock("second run done")
Example #4
class GeoDatasetRank(object):
    def __init__(self,
                 p="geoquery_gen/run4/",
                 min_freq: int = 2,
                 splits=None,
                 **kw):
        super(GeoDatasetRank, self).__init__(**kw)
        self._initialize(p)
        self.splits_proportions = splits

    def _initialize(self, p):
        self.data = {}
        with open(os.path.join(p, "trainpreds.json")) as f:
            trainpreds = ujson.load(f)
        with open(os.path.join(p, "testpreds.json")) as f:
            testpreds = ujson.load(f)
        splits = ["train"] * len(trainpreds) + ["test"] * len(testpreds)
        preds = trainpreds + testpreds

        self.sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
        self.query_encoder = SequenceEncoder(tokenizer=lambda x: x.split())

        # build vocabularies
        for i, (example, split) in enumerate(zip(preds, splits)):
            self.sentence_encoder.inc_build_vocab(" ".join(
                example["sentence"]),
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(" ".join(example["gold"]),
                                               seen=split == "train")
            for can in example["candidates"]:
                self.query_encoder.inc_build_vocab(" ".join(can["tokens"]),
                                                   seen=False)
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab()
        self.query_encoder.finalize_vocab()

        self.build_data(preds, splits)

    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            if not isinstance(gold_tree, Tree):
                assert (gold_tree is not None)
            gold_tensor, gold_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")

            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]),
                                            None)
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert (cand_tree is not None)
                cand_tensor, cand_tokens = self.query_encoder.convert(
                    " ".join(cand["tokens"]), return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(
                    cand["alignments"]))
                candidate_align_entropies.append(
                    torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                candidate_same.append(
                    are_equal_trees(cand_tree,
                                    gold_tree,
                                    orderless={"and", "or"},
                                    unktoken="@NOUNKTOKENHERE@"))

            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0),
                                           0)
            candidate_align_tensor = torch.stack(
                q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(
                q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            state = RankState(
                inp_tensor[None, :],
                gold_tensor[None, :],
                candidate_tensor[None, :, :],
                candidate_same[None, :],
                candidate_align_tensor[None, :],
                candidate_align_entropy[None, :],
                self.sentence_encoder.vocab,
                self.query_encoder.vocab,
            )
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                             gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
            goldmaxlen = max(goldmaxlen, state.candtensors.size(-1))
        inp_tensors = q.pad_tensors([state.inp_tensor for state in data], 1, 0)
        gold_tensors = q.pad_tensors([state.gold_tensor for state in data], 1,
                                     0)
        candtensors = q.pad_tensors([state.candtensors for state in data], 2,
                                    0)
        alignments = q.pad_tensors([state.alignments for state in data], 2, 0)
        alignment_entropies = q.pad_tensors(
            [state.alignment_entropies for state in data], 2, 0)

        for i, state in enumerate(data):
            state.inp_tensor = inp_tensors[i]
            state.gold_tensor = gold_tensors[i]
            state.candtensors = candtensors[i]
            state.alignments = alignments[i]
            state.alignment_entropies = alignment_entropies[i]
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            assert (split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
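A hypothetical usage sketch: the directory passed as p is assumed to contain the trainpreds.json and testpreds.json files written by an earlier generation run; each batch is a single RankState merged by the class's own collate_fn.

ds = GeoDatasetRank(p="geoquery_gen/run4/")
train_dl = ds.dataloader("train", batsize=8, shuffle=True)
batch = next(iter(train_dl))      # one RankState covering 8 padded examples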
Example #5
class LCQuaDnoENTDataset(object):
    def __init__(self,
                 p="../../datasets/lcquad/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 splits=None,
                 **kw):
        super(LCQuaDnoENTDataset, self).__init__(**kw)
        self._simplify_filters = True  # if True, filter expressions are converted to orderless and-expressions
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def lines_to_examples(self, lines: List[str]):
        maxsize_before = 0
        avgsize_before = []
        maxsize_after = 0
        avgsize_after = []
        afterstring = set()

        def convert_to_lispstr(_x):
            splits = _x.split()
            assert (sum([1 if xe == "~" else 0 for xe in splits]) == 1)
            assert (splits[1] == "~")
            splits = ["," if xe == "&" else xe for xe in splits]
            pstr = f"{splits[0]} ({' '.join(splits[2:])})"
            return pstr

        ret = []
        ltp = None
        j = 0
        for i, line in enumerate(lines):
            question = line["question"]
            query = line["logical_form"]
            query = convert_to_lispstr(query)
            z, ltp = prolog_to_pas(query, ltp)
            if z is not None:
                ztree = pas_to_tree(z)
                maxsize_before = max(maxsize_before, tree_size(ztree))
                avgsize_before.append(tree_size(ztree))
                lf = ztree
                ret.append((question, lf))
                # print(f"Example {j}:")
                # print(ret[-1][0])
                # print(ret[-1][1])
                # print()
                ltp = None
                maxsize_after = max(maxsize_after, tree_size(lf))
                avgsize_after.append(tree_size(lf))
                j += 1

        avgsize_before = sum(avgsize_before) / len(avgsize_before)
        avgsize_after = sum(avgsize_after) / len(avgsize_after)

        print("Sizes ({j} examples):")
        # print(f"\t Max, Avg size before: {maxsize_before}, {avgsize_before}")
        print(f"\t Max, Avg size: {maxsize_after}, {avgsize_after}")

        return ret

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder

        jp = os.path.join(p, "lcquad_dataset.json")
        with open(jp, "r") as f:
            examples = ujson.load(f)

        examples = self.lines_to_examples(examples)

        questions, queries = tuple(zip(*examples))
        trainlen = int(round(0.8 * len(examples)))
        validlen = int(round(0.1 * len(examples)))
        testlen = int(round(0.1 * len(examples)))
        splits = ["train"] * trainlen + ["valid"] * validlen + ["test"
                                                                ] * testlen
        random.seed(1337)
        random.shuffle(splits)
        assert (len(splits) == len(examples))

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            tree_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        eid = 0

        gold_map = torch.arange(
            0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
        rare_tokens = self.query_encoder.vocab.rare_tokens - set(
            self.sentence_encoder.vocab.D.keys())
        for rare_token in rare_tokens:
            gold_map[self.query_encoder.vocab[rare_token]] = \
                self.query_encoder.vocab[self.query_encoder.vocab.unktoken]

        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [out], inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens], self.sentence_encoder.vocab,
                                     self.query_encoder.vocab)
            state.eids = np.asarray([eid], dtype="int64")
            maxlen_in, maxlen_out = max(maxlen_in,
                                        len(state.inp_tokens[0])), max(
                                            maxlen_out,
                                            len(state.gold_tokens[0]))
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            eid += 1
        self.maxlen_input, self.maxlen_output = maxlen_in, maxlen_out

    def get_split(self, split: str):
        splits = split.split("+")
        data = []
        for split in splits:
            data += self.data[split]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split in ("train", "train+valid"),
                            collate_fn=type(self).collate_fn)
            return dl
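A hypothetical usage sketch; the whitespace tokenizer for the sentence encoder is an assumption, not something the class prescribes.

sentence_enc = SequenceEncoder(tokenizer=lambda x: x.split())
ds = LCQuaDnoENTDataset(p="../../datasets/lcquad/", sentence_encoder=sentence_enc, min_freq=2)
loaders = ds.dataloader(batsize=16)    # "train", "valid" and "test" DataLoaders using collate_fn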
Example #6
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder:SequenceEncoder=None,
                 min_freq:int=2,
                 splits=None, **kw):
        super(GeoDataset, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def _initialize(self, p, sentence_encoder:SequenceEncoder, min_freq:int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]
        splits = ["train"]*len(trainlines) + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer), add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split=="train")
            self.query_encoder.inc_build_vocab(query, seen=split=="train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs:Iterable[str], outputs:Iterable[str], splits:Iterable[str], unktokens:Set[str]=None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert(gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                      inp_tensor[None, :], out_tensor[None, :],
                                      [inp_tokens], [out_tokens],
                                      self.sentence_encoder.vocab, self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split:str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data:Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split:str=None, batsize:int=5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert(split in self.data.keys())
            dl = DataLoader(self.get_split(split), batch_size=batsize, shuffle=split == "train",
                            collate_fn=type(self).collate_fn)
            return dl
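A hypothetical usage sketch under the same assumptions (whitespace sentence tokenizer, dataset files present under p):

sentence_enc = SequenceEncoder(tokenizer=lambda x: x.split())
ds = GeoDataset(p="../../datasets/geo880dong/", sentence_encoder=sentence_enc, min_freq=2)
train_dl = ds.dataloader("train", batsize=10)   # shuffled, padded with GeoDataset.collate_fn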
Example #7
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 cvfolds=None,
                 testfold=None,
                 reorder_random=False,
                 **kw):
        super(GeoDataset, self).__init__(**kw)
        self.cvfolds, self.testfold = cvfolds, testfold
        self.reorder_random = reorder_random
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [
            x.strip()
            for x in open(os.path.join(p, "train.txt"), "r").readlines()
        ]
        testlines = [
            x.strip()
            for x in open(os.path.join(p, "test.txt"), "r").readlines()
        ]
        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = [
                "valid" if x == self.testfold else "train" for x in splits
            ]
            splits = splits + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.inf, -np.inf]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        token_specs["and"][1] = np.infty

        return token_specs

    def build_data(self,
                   inputs: Iterable[str],
                   outputs: Iterable[str],
                   splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0,
                                    self.query_encoder.vocab.number_of_ids())
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens],
                                     self.sentence_encoder.vocab,
                                     self.query_encoder.vocab,
                                     token_specs=self.token_specs)
            if split == "train" and self.reorder_random is True:
                gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
                random_gold_tree = random.choice(
                    get_tree_permutations(gold_tree_, orderless={"and"}))
                out_ = tree_to_lisp(random_gold_tree)
                out_tensor_, out_tokens_ = self.query_encoder.convert(
                    out_, return_what="tensor,tokens")
                if gold_map is not None:
                    out_tensor_ = gold_map[out_tensor_]
                state.gold_tensor = out_tensor_[None]

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
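A hypothetical usage sketch for the cross-validation variant: with cvfolds=5 and testfold=0, one fold of the training lines is relabeled "valid" and the rest stays "train", so dataloader() returns three splits.

sentence_enc = SequenceEncoder(tokenizer=lambda x: x.split())
ds = GeoDataset(sentence_encoder=sentence_enc, cvfolds=5, testfold=0)
loaders = ds.dataloader(batsize=10)    # dict with "train", "valid" and "test" DataLoaders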
Example #8
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880_multiling/geoquery/",
                 train_lang="en",
                 test_lang=None,
                 bert_tokenizer=None,
                 min_freq: int = 2,
                 cvfolds=None,
                 testfold=None,
                 **kw):
        super(GeoDataset, self).__init__(**kw)
        self.train_lang = train_lang
        self.test_lang = test_lang if test_lang is not None else train_lang
        self.cvfolds, self.testfold = cvfolds, testfold
        self._initialize(p, bert_tokenizer, min_freq)

    def _initialize(self, p, bert_tokenizer, min_freq: int):
        self.data = {}
        self.bert_vocab = Vocab()
        self.bert_vocab.set_dict(bert_tokenizer.vocab)
        self.sentence_encoder = SequenceEncoder(
            lambda x: bert_tokenizer.tokenize(f"[CLS] {x} [SEP]"),
            vocab=self.bert_vocab)
        trainlines = [
            x for x in ujson.load(
                open(os.path.join(p, f"geo-{self.train_lang}.json"), "r"))
        ]
        testlines = [
            x for x in ujson.load(
                open(os.path.join(p, f"geo-{self.test_lang}.json"), "r"))
        ]
        trainlines = [x for x in trainlines if x["split"] == "train"]
        testlines = [x for x in testlines if x["split"] == "test"]
        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = [
                "valid" if x == self.testfold else "train" for x in splits
            ]
            splits = splits + ["test"] * len(testlines)
        questions = [x["nl"] for x in trainlines]
        queries = [x["mrl"] for x in trainlines]
        xquestions = [x["nl"] for x in testlines]
        xqueries = [x["mrl"] for x in testlines]
        questions += xquestions
        queries += xqueries

        # initialize output vocabulary
        outvocab = Vocab()
        for token, bertid in self.bert_vocab.D.items():
            outvocab.add_token(token, seen=False)

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=bert_tokenizer),
                                             vocab=outvocab,
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        keeptokens = set(self.bert_vocab.D.keys())
        self.query_encoder.finalize_vocab(min_freq=min_freq,
                                          keep_tokens=keeptokens)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.inf, -np.inf]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        # token_specs["and"][1] = np.infty

        return token_specs

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for inp, out, split in zip(inputs, outputs, splits):
            # tokenize both input and output
            inp_tokens = self.sentence_encoder.convert(inp,
                                                       return_what="tokens")[0]
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            # get gold tree
            gold_tree = lisp_to_tree(" ".join(out_tokens[:-1]))
            assert (gold_tree is not None)
            # replace words in output that can't be copied from given input to UNK tokens
            unktoken = self.query_encoder.vocab.unktoken
            inp_tokens_ = set(inp_tokens)
            out_tokens = [
                out_token if out_token in inp_tokens_ or
                (out_token in self.query_encoder.vocab
                 and not out_token in self.query_encoder.vocab.rare_tokens)
                else unktoken for out_token in out_tokens
            ]
            # convert token sequences to ids
            inp_tensor = self.sentence_encoder.convert(inp_tokens,
                                                       return_what="tensor")[0]
            out_tensor = self.query_encoder.convert(out_tokens,
                                                    return_what="tensor")[0]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens],
                                     self.sentence_encoder.vocab,
                                     self.query_encoder.vocab,
                                     token_specs=self.token_specs)

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
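A hypothetical usage sketch with a slow HuggingFace BertTokenizer, whose .vocab dict and .tokenize() method are exactly what _initialize relies on; the model name is an assumption.

from transformers import BertTokenizer

bert_tok = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
ds = GeoDataset(train_lang="en", bert_tokenizer=bert_tok)
train_dl = ds.dataloader("train", batsize=10)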
Example #9
class GeoQueryDatasetFunQL(object):
    def __init__(self,
                 p="../../datasets/geoquery/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 **kw):
        super(GeoQueryDatasetFunQL, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        questions = [
            x.strip()
            for x in open(os.path.join(p, "questions.txt"), "r").readlines()
        ]
        queries = [
            x.strip()
            for x in open(os.path.join(p, "queries.funql"), "r").readlines()
        ]
        trainidxs = set([
            int(x.strip()) for x in open(os.path.join(p, "train_indexes.txt"),
                                         "r").readlines()
        ])
        testidxs = set([
            int(x.strip()) for x in open(os.path.join(p, "test_indexes.txt"),
                                         "r").readlines()
        ])
        splits = [None] * len(questions)
        for trainidx in trainidxs:
            splits[trainidx] = "train"
        for testidx in testidxs:
            splits[testidx] = "test"
        if any([split is None for split in splits]):
            print(
                f"{len([split for split in splits if split is None])} examples not assigned to any split"
            )

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        unktokens = set()
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            question_tokens = self.sentence_encoder.inc_build_vocab(
                question, seen=split == "train")
            query_tokens = self.query_encoder.inc_build_vocab(
                query, seen=split == "train")
            unktokens |= set(query_tokens) - set(question_tokens)
        for word in self.sentence_encoder.vocab.counts.keys():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        unktokens = unktokens & self.query_encoder.vocab.rare_tokens

        self.build_data(questions, queries, splits, unktokens=unktokens)

    def build_data(self,
                   inputs: Iterable[str],
                   outputs: Iterable[str],
                   splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None
        if unktokens is not None:
            gold_map = torch.arange(
                0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = prolog_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree], inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens], self.sentence_encoder.vocab,
                                     self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert (split in self.data.keys())
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=GeoQueryDatasetFunQL.collate_fn)
            return dl
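A hypothetical usage sketch; as before, the whitespace sentence tokenizer is an assumption.

sentence_enc = SequenceEncoder(tokenizer=lambda x: x.split())
ds = GeoQueryDatasetFunQL(p="../../datasets/geoquery/", sentence_encoder=sentence_enc, min_freq=2)
loaders = ds.dataloader(batsize=20)    # "train" and "test" DataLoaders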