Example #1
    def test_tf_decoder(self):
        texts = [
            "i went to chocolate @END@", "awesome is @END@",
            "the meaning of life @END@"
        ]
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)

        class Model(TransitionModel):
            def forward(self, x: BasicDecoderState):
                outprobs = torch.rand(len(x),
                                      x.query_encoder.vocab.number_of_ids())
                return outprobs, x

        dec = SeqDecoder(TFTransition(Model()))

        y = dec(x)
        print(y[1].followed_actions)
        outactions = y[1].followed_actions.detach().cpu().numpy()
        print(outactions[0])
        print(se.vocab.print(outactions[0]))
        print(se.vocab.print(outactions[1]))
        print(se.vocab.print(outactions[2]))
        self.assertTrue(se.vocab.print(outactions[0]) == texts[0])
        self.assertTrue(se.vocab.print(outactions[1]) == texts[1])
        self.assertTrue(se.vocab.print(outactions[2]) == texts[2])
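All of the decoder tests in this listing build their vocabulary the same way. Below is a minimal standalone sketch of that round trip (text to ids and back), assuming only the parseq SequenceEncoder API demonstrated in these examples:

from parseq.vocab import SequenceEncoder

se = SequenceEncoder(tokenizer=lambda x: x.split())
se.inc_build_vocab("the meaning of life @END@")
se.finalize_vocab()
# convert() returns one entry per requested "what"; [0] is the id tensor
ids = se.convert("the meaning of life @END@", return_what="tensor")[0]
print(se.vocab.number_of_ids())
# vocab.print() maps ids back to a token string, which is what the assertions above rely on
print(se.vocab.print(ids.detach().cpu().numpy()))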
Example #2
    def test_beam_search(self):
        texts = [
            "i went to chocolate @END@", "awesome is @END@",
            "the meaning of life @END@"
        ]
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            def forward(self, x: BasicDecoderState):
                outprobs = torch.randn(len(x),
                                       x.query_encoder.vocab.number_of_ids())
                outprobs = torch.nn.functional.log_softmax(outprobs, -1)
                return outprobs, x

        model = Model()

        beamsize = 50
        maxtime = 10
        bs = BeamDecoder(model,
                         eval=[CELoss(ignore_index=0),
                               SeqAccuracies()],
                         eval_beam=[BeamSeqAccuracies()],
                         beamsize=beamsize,
                         maxtime=maxtime)

        y = bs(x)
        print(y)
Example #3
    def test_free_decoder(self):
        texts = [
            "i went to chocolate a b c d e f g h i j k l m n o p q r @END@",
            "awesome is @END@", "the meaning of life @END@"
        ]
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()

        texts = ["@END@"] * 100

        x = BasicDecoderState(texts, texts, se, se)

        class Model(TransitionModel):
            def forward(self, x: BasicDecoderState):
                outprobs = torch.rand(len(x),
                                      x.query_encoder.vocab.number_of_ids())
                return outprobs, x

        MAXTIME = 10
        dec = SeqDecoder(FreerunningTransition(Model(), maxtime=MAXTIME))

        y = dec(x)
        print(y[1].followed_actions)
        print(max([len(y[1].followed_actions[i]) for i in range(len(y[1]))]))
        print(min([len(y[1].followed_actions[i]) for i in range(len(y[1]))]))
        self.assertTrue(
            max([len(y[1].followed_actions[i])
                 for i in range(len(y[1]))]) <= MAXTIME + 1)
Example #4
def load_ds(domain="restaurants",
            min_freq=0,
            top_k=np.infty,
            nl_mode="bart-large",
            trainonvalid=False):
    ds = OvernightDatasetLoader(simplify_mode="light").load(
        domain=domain, trainonvalid=trainonvalid)

    seqenc_vocab = Vocab(padid=1, startid=0, endid=2, unkid=UNKID)
    seqenc = SequenceEncoder(vocab=seqenc_vocab,
                             tokenizer=tree_to_lisp_tokens,
                             add_start_token=True,
                             add_end_token=True)
    for example in ds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] == "train")
    seqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)

    def tokenize(x):
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               seqenc.convert(x[1], return_what="tensor"), x[2], x[0], x[1])
        return ret
    tds, vds, xds = ds[(None, None, "train")].map(tokenize), \
                    ds[(None, None, "valid")].map(tokenize), \
                    ds[(None, None, "test")].map(tokenize)
    return tds, vds, xds, nl_tokenizer, seqenc
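A minimal call sketch for this loader (hedged: it assumes the Overnight data consumed by OvernightDatasetLoader and the tokenizer named by nl_mode are available locally):

tds, vds, xds, nl_tokenizer, seqenc = load_ds(domain="restaurants", min_freq=0)
# each mapped item follows the tokenize() tuple above:
# (NL token ids, encoded LF sequence, split, raw NL, raw LF)
print(tds[0])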
Example #5
    def test_beam_transition(self):
        texts = [
            "i went to chocolate @END@", "awesome is @END@",
            "the meaning of life @END@"
        ]
        from parseq.vocab import SequenceEncoder
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        x.start_decoding()

        class Model(TransitionModel):
            def forward(self, x: BasicDecoderState):
                outprobs = torch.randn(len(x),
                                       x.query_encoder.vocab.number_of_ids())
                outprobs = torch.nn.functional.log_softmax(outprobs, -1)
                return outprobs, x

        model = Model()

        beamsize = 50
        maxtime = 10
        beam_xs = [
            x.make_copy(detach=False, deep=True) for _ in range(beamsize)
        ]
        beam_states = BeamState(beam_xs)

        print(len(beam_xs))
        print(len(beam_states))

        bt = BeamTransition(model, beamsize, maxtime=maxtime)
        i = 0
        _, _, y, _ = bt(x, i)
        i += 1
        _, _, y, _ = bt(y, i)

        all_terminated = y.all_terminated()
        while not all_terminated:
            _, predactions, y, all_terminated = bt(y, i)
            i += 1

        print("timesteps done:")
        print(i)
        print(y)
        print(predactions[0])
        for i in range(beamsize):
            print("-")
            # print(y.bstates[0].get(i).followed_actions)
            # print(predactions[0, i, :])
            pa = predactions[0, i, :]
            # print((pa == se.vocab[se.vocab.endtoken]).cumsum(0))
            pa = ((pa == se.vocab[se.vocab.endtoken]).long().cumsum(0) <
                  1).long() * pa
            yb = y.bstates[0].get(i).followed_actions[0, :]
            yb = yb * (yb != se.vocab[se.vocab.endtoken]).long()
            print(pa)
            print(yb)
            self.assertTrue(torch.allclose(pa, yb))
Example #6
    def test_beam_search_vs_greedy(self):
        with torch.no_grad():
            texts = ["a b"] * 10
            from parseq.vocab import SequenceEncoder
            se = SequenceEncoder(tokenizer=lambda x: x.split())
            for t in texts:
                se.inc_build_vocab(t)
            se.finalize_vocab()
            x = BasicDecoderState(texts, texts, se, se)
            x.start_decoding()

            class Model(TransitionModel):
                transition_tensor = torch.tensor([[0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .51, .49],
                                                  [0, 0, 0, 0, .01, .99]])

                def forward(self, x: BasicDecoderState):
                    prev = x.prev_actions
                    outprobs = self.transition_tensor[prev]
                    outprobs = torch.log(outprobs)
                    return outprobs, x

            model = Model()

            beamsize = 50
            maxtime = 10
            beam_xs = [
                x.make_copy(detach=False, deep=True) for _ in range(beamsize)
            ]
            beam_states = BeamState(beam_xs)

            print(len(beam_xs))
            print(len(beam_states))

            bt = BeamTransition(model, beamsize, maxtime=maxtime)
            i = 0
            _, _, y, _ = bt(x, i)
            i += 1
            _, _, y, _ = bt(y, i)

            all_terminated = y.all_terminated()
            while not all_terminated:
                start_time = time.time()
                _, _, y, all_terminated = bt(y, i)
                i += 1
                # print(i)
                end_time = time.time()
                print(f"{i}: {end_time - start_time}")

            print(y)
            print(y.bstates.get(0).followed_actions)
Example #7
class GeoQueryDatasetSub(GeoQueryDatasetFunQL):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 **kw):
        super(GeoQueryDatasetSub, self).__init__(p, sentence_encoder, min_freq,
                                                 **kw)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [
            x.strip()
            for x in open(os.path.join(p, "train.txt"), "r").readlines()
        ]
        testlines = [
            x.strip()
            for x in open(os.path.join(p, "test.txt"), "r").readlines()
        ]
        splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        queries = self.lisp2prolog(queries)

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def lisp2prolog(self, data: List[str]):
        ret = []
        for x in data:
            pas = lisp_to_pas(x)
            prolog = pas_to_prolog(pas)
            ret.append(prolog)
        return ret
Example #8
class ConditionalRecallDataset(object):
    def __init__(self, maxlen=10, NperY=10, **kw):
        super(ConditionalRecallDataset, self).__init__(**kw)
        self.data = {}
        self.NperY, self.maxlen = NperY, maxlen
        self._seqs, self._ys = gen_data(self.maxlen, self.NperY)
        self.encoder = SequenceEncoder(tokenizer=lambda x: list(x))

        for seq, y in zip(self._seqs, self._ys):
            self.encoder.inc_build_vocab(seq)
            self.encoder.inc_build_vocab(y)

        self.N = len(self._seqs)
        N = self.N

        splits = ["train"] * int(N * 0.8) + ["valid"] * int(
            N * 0.1) + ["test"] * int(N * 0.1)
        random.shuffle(splits)

        self.encoder.finalize_vocab()
        self.build_data(self._seqs, self._ys, splits)

    def build_data(self, seqs, ys, splits):
        for seq, y, split in zip(seqs, ys, splits):
            seq_tensor = self.encoder.convert(seq, return_what="tensor")
            y_tensor = self.encoder.convert(y, return_what="tensor")
            if split not in self.data:
                self.data[split] = []
            self.data[split].append((seq_tensor[0], y_tensor[0][0]))

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            assert (split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle)
            return dl
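A minimal usage sketch of the class above (hedged: it assumes gen_data, DatasetSplitProxy and torch's DataLoader, which the class relies on, are defined or imported in the surrounding module):

ds = ConditionalRecallDataset(maxlen=10, NperY=10)
train_dl = ds.dataloader(split="train", batsize=5)   # shuffled, since split == "train"
all_dls = ds.dataloader(batsize=5)                   # dict of DataLoaders, one per split
print(len(ds.data["train"]), len(ds.data["valid"]), len(ds.data["test"]))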
Example #9
    def test_decoder_API(self):
        texts = ["i went to chocolate", "awesome is", "the meaning of life"]
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)
        print(x.inp_tensor)
        print("terminated")
        print(x.is_terminated())
        print(x.all_terminated())
        print("prev_actions")
        x.start_decoding()
        print(x.prev_actions)
        print("step")
        x.step(["i", torch.tensor([7]), "the"])
        print(x.prev_actions)
        print(x.followed_actions)
Example #10
def try_tokenizer_dataset():
    from transformers import BartTokenizer

    ovd = OvernightDatasetLoader().load()
    seqenc = SequenceEncoder(tokenizer=tree_to_lisp_tokens)
    for example in ovd.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] == "train")
    seqenc.finalize_vocab()
    nl_tokenizer = BartTokenizer.from_pretrained("bart-large")
    def tokenize(x):
        ret = [xe for xe in x]
        ret.append(nl_tokenizer.tokenize(ret[0]))
        ret.append(nl_tokenizer.encode(ret[0], return_tensors="pt"))
        ret.append(seqenc.convert(ret[1], return_what="tensor")[0][None])
        return ret
    ovd = ovd.map(tokenize)
    print(ovd[0])
Example #11
    def test_tf_decoder_with_losses_with_gold(self):
        texts = [
            "i went to chocolate @END@", "awesome is @END@",
            "the meaning of life @END@"
        ]
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)

        class Model(TransitionModel):
            def forward(self, x: BasicDecoderState):
                outprobs = torch.zeros(len(x),
                                       x.query_encoder.vocab.number_of_ids())
                golds = x.get_gold().gather(
                    1,
                    torch.tensor(x._timesteps).to(torch.long)[:, None])
                outprobs.scatter_(1, golds, 1)
                return outprobs, x

        celoss = CELoss(ignore_index=0)
        accs = SeqAccuracies()

        dec = SeqDecoder(TFTransition(Model()), eval=[celoss, accs])

        y = dec(x)

        print(y[0])
        print(y[1].followed_actions)
        print(y[1].get_gold())

        self.assertEqual(y[0]["seq_acc"], 1)
        self.assertEqual(y[0]["elem_acc"], 1)

        # print(y[1].followed_actions)
        outactions = y[1].followed_actions.detach().cpu().numpy()
        # print(outactions[0])
        # print(se.vocab.print(outactions[0]))
        # print(se.vocab.print(outactions[1]))
        # print(se.vocab.print(outactions[2]))
        self.assertTrue(se.vocab.print(outactions[0]) == texts[0])
        self.assertTrue(se.vocab.print(outactions[1]) == texts[1])
        self.assertTrue(se.vocab.print(outactions[2]) == texts[2])
Example #12
    def test_create(self):
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        texts = [
            "i went to chocolate", "awesome is @PAD@ @PAD@",
            "the meaning of life"
        ]
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = [BasicDecoderState([t], [t], se, se) for t in texts]
        merged_x = x[0].merge(x)
        texts = ["i went to chocolate", "awesome is", "the meaning of life"]
        batch_x = BasicDecoderState(texts, texts, se, se)
        print(merged_x.inp_tensor)
        print(batch_x.inp_tensor)
        self.assertTrue(torch.allclose(merged_x.inp_tensor,
                                       batch_x.inp_tensor))
        self.assertTrue(
            torch.allclose(merged_x.gold_tensor, batch_x.gold_tensor))
Example #13
    def test_tf_decoder_with_losses(self):
        texts = [
            "i went to chocolate @END@", "awesome is @END@",
            "the meaning of life @END@"
        ]
        se = SequenceEncoder(tokenizer=lambda x: x.split())
        for t in texts:
            se.inc_build_vocab(t)
        se.finalize_vocab()
        x = BasicDecoderState(texts, texts, se, se)

        class Model(TransitionModel):
            def forward(self, x: BasicDecoderState):
                outprobs = torch.rand(len(x),
                                      x.query_encoder.vocab.number_of_ids())
                outprobs = torch.nn.functional.log_softmax(outprobs, -1)
                return outprobs, x

        celoss = CELoss(ignore_index=0)
        accs = SeqAccuracies()

        dec = SeqDecoder(TFTransition(Model()), eval=[celoss, accs])

        y = dec(x)

        print(y[0])
        print(y[1].followed_actions)
        print(y[1].get_gold())

        # print(y[1].followed_actions)
        outactions = y[1].followed_actions.detach().cpu().numpy()
        # print(outactions[0])
        # print(se.vocab.print(outactions[0]))
        # print(se.vocab.print(outactions[1]))
        # print(se.vocab.print(outactions[2]))
        self.assertTrue(se.vocab.print(outactions[0]) == texts[0])
        self.assertTrue(se.vocab.print(outactions[1]) == texts[1])
        self.assertTrue(se.vocab.print(outactions[2]) == texts[2])
Example #14
class GeoDatasetRank(object):
    def __init__(self,
                 p="geoquery_gen/run4/",
                 min_freq: int = 2,
                 splits=None,
                 **kw):
        super(GeoDatasetRank, self).__init__(**kw)
        self._initialize(p)
        self.splits_proportions = splits

    def _initialize(self, p):
        self.data = {}
        with open(os.path.join(p, "trainpreds.json")) as f:
            trainpreds = ujson.load(f)
        with open(os.path.join(p, "testpreds.json")) as f:
            testpreds = ujson.load(f)
        splits = ["train"] * len(trainpreds) + ["test"] * len(testpreds)
        preds = trainpreds + testpreds

        self.sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
        self.query_encoder = SequenceEncoder(tokenizer=lambda x: x.split())

        # build vocabularies
        for i, (example, split) in enumerate(zip(preds, splits)):
            self.sentence_encoder.inc_build_vocab(" ".join(
                example["sentence"]),
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(" ".join(example["gold"]),
                                               seen=split == "train")
            for can in example["candidates"]:
                self.query_encoder.inc_build_vocab(" ".join(can["tokens"]),
                                                   seen=False)
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab()
        self.query_encoder.finalize_vocab()

        self.build_data(preds, splits)

    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            if not isinstance(gold_tree, Tree):
                assert (gold_tree is not None)
            gold_tensor, gold_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")

            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]),
                                            None)
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert (cand_tree is not None)
                cand_tensor, cand_tokens = self.query_encoder.convert(
                    " ".join(cand["tokens"]), return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(
                    cand["alignments"]))
                candidate_align_entropies.append(
                    torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                candidate_same.append(
                    are_equal_trees(cand_tree,
                                    gold_tree,
                                    orderless={"and", "or"},
                                    unktoken="@NOUNKTOKENHERE@"))

            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0),
                                           0)
            candidate_align_tensor = torch.stack(
                q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(
                q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            state = RankState(
                inp_tensor[None, :],
                gold_tensor[None, :],
                candidate_tensor[None, :, :],
                candidate_same[None, :],
                candidate_align_tensor[None, :],
                candidate_align_entropy[None, :],
                self.sentence_encoder.vocab,
                self.query_encoder.vocab,
            )
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                             gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
            goldmaxlen = max(goldmaxlen, state.candtensors.size(-1))
        inp_tensors = q.pad_tensors([state.inp_tensor for state in data], 1, 0)
        gold_tensors = q.pad_tensors([state.gold_tensor for state in data], 1,
                                     0)
        candtensors = q.pad_tensors([state.candtensors for state in data], 2,
                                    0)
        alignments = q.pad_tensors([state.alignments for state in data], 2, 0)
        alignment_entropies = q.pad_tensors(
            [state.alignment_entropies for state in data], 2, 0)

        for i, state in enumerate(data):
            state.inp_tensor = inp_tensors[i]
            state.gold_tensor = gold_tensors[i]
            state.candtensors = candtensors[i]
            state.alignments = alignments[i]
            state.alignment_entropies = alignment_entropies[i]
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            assert (split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
Example #15
class LCQuaDnoENTDataset(object):
    def __init__(self,
                 p="../../datasets/lcquad/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 splits=None,
                 **kw):
        super(LCQuaDnoENTDataset, self).__init__(**kw)
        self._simplify_filters = True  # if True, filter expressions are converted to orderless and-expressions
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def lines_to_examples(self, lines: List[str]):
        maxsize_before = 0
        avgsize_before = []
        maxsize_after = 0
        avgsize_after = []
        afterstring = set()

        def convert_to_lispstr(_x):
            splits = _x.split()
            assert (sum([1 if xe == "~" else 0 for xe in splits]) == 1)
            assert (splits[1] == "~")
            splits = ["," if xe == "&" else xe for xe in splits]
            pstr = f"{splits[0]} ({' '.join(splits[2:])})"
            return pstr

        ret = []
        ltp = None
        j = 0
        for i, line in enumerate(lines):
            question = line["question"]
            query = line["logical_form"]
            query = convert_to_lispstr(query)
            z, ltp = prolog_to_pas(query, ltp)
            if z is not None:
                ztree = pas_to_tree(z)
                maxsize_before = max(maxsize_before, tree_size(ztree))
                avgsize_before.append(tree_size(ztree))
                lf = ztree
                ret.append((question, lf))
                # print(f"Example {j}:")
                # print(ret[-1][0])
                # print(ret[-1][1])
                # print()
                ltp = None
                maxsize_after = max(maxsize_after, tree_size(lf))
                avgsize_after.append(tree_size(lf))
                j += 1

        avgsize_before = sum(avgsize_before) / len(avgsize_before)
        avgsize_after = sum(avgsize_after) / len(avgsize_after)

        print("Sizes ({j} examples):")
        # print(f"\t Max, Avg size before: {maxsize_before}, {avgsize_before}")
        print(f"\t Max, Avg size: {maxsize_after}, {avgsize_after}")

        return ret

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder

        jp = os.path.join(p, "lcquad_dataset.json")
        with open(jp, "r") as f:
            examples = ujson.load(f)

        examples = self.lines_to_examples(examples)

        questions, queries = tuple(zip(*examples))
        trainlen = int(round(0.8 * len(examples)))
        validlen = int(round(0.1 * len(examples)))
        testlen = int(round(0.1 * len(examples)))
        splits = ["train"] * trainlen + ["valid"] * validlen + ["test"
                                                                ] * testlen
        random.seed(1337)
        random.shuffle(splits)
        assert (len(splits) == len(examples))

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            tree_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        eid = 0

        gold_map = torch.arange(
            0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
        rare_tokens = self.query_encoder.vocab.rare_tokens - set(
            self.sentence_encoder.vocab.D.keys())
        for rare_token in rare_tokens:
            gold_map[self.query_encoder.vocab[rare_token]] = \
                self.query_encoder.vocab[self.query_encoder.vocab.unktoken]

        for inp, out, split in zip(inputs, outputs, splits):
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [out], inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens], self.sentence_encoder.vocab,
                                     self.query_encoder.vocab)
            state.eids = np.asarray([eid], dtype="int64")
            maxlen_in, maxlen_out = max(maxlen_in,
                                        len(state.inp_tokens[0])), max(
                                            maxlen_out,
                                            len(state.gold_tokens[0]))
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            eid += 1
        self.maxlen_input, self.maxlen_output = maxlen_in, maxlen_out

    def get_split(self, split: str):
        splits = split.split("+")
        data = []
        for split in splits:
            data += self.data[split]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split in ("train", "train+valid"),
                            collate_fn=type(self).collate_fn)
            return dl
Example #16
def load_ds(traindomains=("restaurants",),
            testdomain="housing",
            min_freq=1,
            mincoverage=1,
            top_k=np.infty,
            nl_mode="bert-base-uncased",
            fullsimplify=False,
            onlyabstract=False,
            pretrainsetting="all+lex",    # "all", "lex" or "all+lex"
            finetunesetting="lex",        # "lex", "all", "min"
            ):
    """
    :param traindomains:
    :param testdomain:
    :param min_freq:
    :param mincoverage:
    :param top_k:
    :param nl_mode:
    :param fullsimplify:
    :param onlyabstract:
    :param pretrainsetting:     "all": use all examples from every domain
                                "lex": use only lexical examples
                                "all+lex": use both
    :param finetunesetting:     "lex": use lexical examples
                                "all": use all training examples
                                "min": use minimal lexicon-covering set of examples
                            ! Test is always over the same original test set.
                            ! Validation is over a fraction of training data
    :return:
    """
    general_tokens = {
        "(", ")", "arg:~type", "arg:type", "op:and", "SW:concat", "cond:has",
        "arg:<=", "arg:<", "arg:>=", "arg:>", "arg:!=", "arg:=", "SW:superlative",
        "SW:CNT-arg:min", "SW:CNT-arg:<", "SW:CNT-arg:<=", "SW:CNT-arg:>=", "SW:CNT-arg:>",
        "SW:CNT-arg:max", "SW:CNT-arg:=", "arg:max",
    }

    def tokenize_and_add_start(t):
        tokens = tree_to_lisp_tokens(t)
        starttok = "@START@"
        tokens = [starttok] + tokens
        return tokens

    sourceex = []
    for traindomain in traindomains:
        ds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full", simplify_blocks=True,
                                    restore_reverse=DATA_RESTORE_REVERSE, validfrac=.10)\
            .load(domain=traindomain)
        sourceex += ds[(None, None, lambda x: x in ("train", "valid", "lexicon"))].map(lambda x: (x[0], x[1], x[2], traindomain)).examples       # don't use test examples

    testds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full", simplify_blocks=True, restore_reverse=DATA_RESTORE_REVERSE)\
        .load(domain=testdomain)

    targetex = testds.map(lambda x: x + (testdomain,)).examples

    pretrainex = []
    if "all" in pretrainsetting.split("+"):
        pretrainex += [(a, tokenize_and_add_start(b), "pretrain", d) for a, b, c, d in sourceex if c == "train"]
    if "lex" in pretrainsetting.split("+"):
        pretrainex += [(a, tokenize_and_add_start(b), "pretrain", d) for a, b, c, d in sourceex if c == "lexicon"]

    pretrainvalidex = [(a, tokenize_and_add_start(b), "pretrainvalid", d) for a, b, c, d in sourceex if c == "valid"]

    if finetunesetting == "all":
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d) for a, b, c, d in targetex if c == "train"]
    elif finetunesetting == "lex":
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d) for a, b, c, d in targetex if c == "lexicon"]
    elif finetunesetting == "min":
        finetunetrainex = get_maximum_spanning_examples([(a, b, c, d) for a, b, c, d in targetex if c == "train"],
                                      mincoverage=mincoverage,
                                      loadedex=[e for e in pretrainex if e[2] == "pretrain"])
        finetunetrainex = [(a, tokenize_and_add_start(b), "fttrain", d) for a, b, c, d in finetunetrainex]
    finetunevalidex = [(a, tokenize_and_add_start(b), "ftvalid", d) for a, b, c, d in targetex if c == "valid"]
    finetunetestex = [(a, tokenize_and_add_start(b), "fttest", d) for a, b, c, d in targetex if c == "test"]
    print(f"Using mode \"{finetunesetting}\" for finetuning data: "
          f"\n\t{len(finetunetrainex)} training examples")


    allex = pretrainex + pretrainvalidex + finetunetrainex + finetunevalidex + finetunetestex
    ds = Dataset(allex)

    if onlyabstract:
        et = get_lf_abstract_transform(ds[lambda x: x[3] != testdomain].examples)
        ds = ds.map(lambda x: (x[0], et(x[1]), x[2], x[3]))

    seqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    seqenc = SequenceEncoder(vocab=seqenc_vocab, tokenizer=lambda x: x,
                             add_start_token=False, add_end_token=True)
    for example in ds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=example[2] in ("pretrain", "fttrain"))
    seqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    generaltokenmask = torch.zeros(seqenc_vocab.number_of_ids(), dtype=torch.long)
    for token, tokenid in seqenc_vocab.D.items():
        if token in general_tokens:
            generaltokenmask[tokenid] = 1

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)
    def tokenize(x):
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               seqenc.convert(x[1], return_what="tensor"),
               x[2],
               x[0], x[1], x[3])
        return ret
    tds, ftds, vds, fvds, xds = ds[(None, None, "pretrain", None)].map(tokenize), \
                          ds[(None, None, "fttrain", None)].map(tokenize), \
                          ds[(None, None, "pretrainvalid", None)].map(tokenize), \
                          ds[(None, None, "ftvalid", None)].map(tokenize), \
                          ds[(None, None, "fttest", None)].map(tokenize)
    return tds, ftds, vds, fvds, xds, nl_tokenizer, seqenc, generaltokenmask
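A minimal call sketch for the pretrain/finetune loader above (hedged: it assumes the Overnight data plus the module-level helpers it references, e.g. OvernightDatasetLoader, get_maximum_spanning_examples, DATA_RESTORE_REVERSE and UNKID, are available):

tds, ftds, vds, fvds, xds, nl_tokenizer, seqenc, generaltokenmask = load_ds(
    traindomains=("restaurants",),
    testdomain="housing",
    pretrainsetting="all+lex",   # pretrain on all source-domain examples plus their lexicon entries
    finetunesetting="lex",       # finetune only on target-domain lexicon examples
)
print(seqenc.vocab.number_of_ids(), int(generaltokenmask.sum()))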
Example #17
def load_ds(domain="restaurants",
            nl_mode="bert-base-uncased",
            trainonvalid=False,
            noreorder=False):
    """
    Creates a dataset of examples which have
    * NL question and tensor
    * original FL tree
    * reduced FL tree with slots (this is randomly generated)
    * tensor corresponding to reduced FL tree with slots
    * mask specifying which elements in reduced FL tree are terminated
    * 2D gold that specifies whether a token/action is in gold for every position (compatibility with MML!)
    """
    orderless = {"op:and", "SW:concat"}  # only use in eval!!

    ds = OvernightDatasetLoader().load(domain=domain,
                                       trainonvalid=trainonvalid)
    ds = ds.map(lambda x: (x[0], ATree("@START@", [x[1]]), x[2]))

    if not noreorder:
        ds = ds.map(lambda x:
                    (x[0], reorder_tree(x[1], orderless=orderless), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    vocab.add_token("@START@", seen=np.infty)
    vocab.add_token(
        "@CLOSE@", seen=np.infty
    )  # only here for the action of closing an open position, will not be seen at input
    vocab.add_token(
        "@OPEN@", seen=np.infty
    )  # only here for the action of opening a closed position, will not be seen at input
    vocab.add_token(
        "@REMOVE@", seen=np.infty
    )  # only here for deletion operations, won't be seen at input
    vocab.add_token(
        "@REMOVESUBTREE@", seen=np.infty
    )  # only here for deletion operations, won't be seen at input
    vocab.add_token("@SLOT@",
                    seen=np.infty)  # will be seen at input, can't be produced!

    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)
    # for tok, idd in nl_tokenizer.vocab.items():
    #     vocab.add_token(tok, seen=np.infty)          # all wordpieces are added for possible later generation

    tds, vds, xds = ds[lambda x: x[2] == "train"], \
                    ds[lambda x: x[2] == "valid"], \
                    ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(
        vocab=vocab,
        tokenizer=lambda x: extract_info(x, onlytokens=True),
        add_start_token=False,
        add_end_token=False)
    for example in tds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=True)
    for example in vds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    for example in xds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        nl = x[0]
        fl = x[1]
        fltoks = extract_info(fl, onlytokens=True)
        seq = seqenc.convert(fltoks, return_what="tensor")
        ret = (nl_tokenizer.encode(nl, return_tensors="pt")[0], seq)
        return ret

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless
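A minimal call sketch for this variant (hedged: it assumes the Overnight data and the ATree, reorder_tree and extract_info helpers imported by the surrounding module):

tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless = load_ds(
    domain="restaurants", noreorder=False)
# first mapped example: BERT-tokenized NL ids plus the encoded linearized tree
print(tds_seq[0])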
Example #18
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder:SequenceEncoder=None,
                 min_freq:int=2,
                 splits=None, **kw):
        super(GeoDataset, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)
        self.splits_proportions = splits

    def _initialize(self, p, sentence_encoder:SequenceEncoder, min_freq:int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [x.strip() for x in open(os.path.join(p, "train.txt"), "r").readlines()]
        testlines = [x.strip() for x in open(os.path.join(p, "test.txt"), "r").readlines()]
        splits = ["train"]*len(trainlines) + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(tokenizer=partial(basic_query_tokenizer, strtok=sentence_encoder.tokenizer), add_end_token=True)

        # build vocabularies
        for i, (question, query, split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question, seen=split=="train")
            self.query_encoder.inc_build_vocab(query, seen=split=="train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs:Iterable[str], outputs:Iterable[str], splits:Iterable[str], unktokens:Set[str]=None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert(gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                      inp_tensor[None, :], out_tensor[None, :],
                                      [inp_tokens], [out_tokens],
                                      self.sentence_encoder.vocab, self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split:str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data:Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(1, goldmaxlen - state.gold_tensor.size(1))], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(1, inpmaxlen - state.inp_tensor.size(1))], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split:str=None, batsize:int=5):
        if split is None:   # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert(split in self.data.keys())
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=type(self).collate_fn)
            return dl
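A minimal construction sketch for the class above (hedged: it assumes the geo880dong train.txt/test.txt files and the module-level helpers such as basic_query_tokenizer, lisp_to_tree and DatasetSplitProxy are available):

sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
ds = GeoDataset(p="../../datasets/geo880dong/",
                sentence_encoder=sentence_encoder,
                min_freq=2)
train_dl = ds.dataloader(split="train", batsize=10)
print(ds.maxlen_input, ds.maxlen_output)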
Example #19
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880dong/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 cvfolds=None,
                 testfold=None,
                 reorder_random=False,
                 **kw):
        super(GeoDataset, self).__init__(**kw)
        self.cvfolds, self.testfold = cvfolds, testfold
        self.reorder_random = reorder_random
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        trainlines = [
            x.strip()
            for x in open(os.path.join(p, "train.txt"), "r").readlines()
        ]
        testlines = [
            x.strip()
            for x in open(os.path.join(p, "test.txt"), "r").readlines()
        ]
        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = [
                "valid" if x == self.testfold else "train" for x in splits
            ]
            splits = splits + ["test"] * len(testlines)
        questions, queries = zip(*[x.split("\t") for x in trainlines])
        testqs, testxs = zip(*[x.split("\t") for x in testlines])
        questions += testqs
        queries += testxs

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        # for word, wordid in self.sentence_encoder.vocab.D.items():
        #     self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.infty, -np.infty]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        token_specs["and"][1] = np.infty

        return token_specs

    def build_data(self,
                   inputs: Iterable[str],
                   outputs: Iterable[str],
                   splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0,
                                    self.query_encoder.vocab.number_of_ids())
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens],
                                     self.sentence_encoder.vocab,
                                     self.query_encoder.vocab,
                                     token_specs=self.token_specs)
            if split == "train" and self.reorder_random is True:
                gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
                random_gold_tree = random.choice(
                    get_tree_permutations(gold_tree_, orderless={"and"}))
                out_ = tree_to_lisp(random_gold_tree)
                out_tensor_, out_tokens_ = self.query_encoder.convert(
                    out_, return_what="tensor,tokens")
                if gold_map is not None:
                    out_tensor_ = gold_map[out_tensor_]
                state.gold_tensor = out_tensor_[None]

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
Example #20
class GeoDataset(object):
    def __init__(self,
                 p="../../datasets/geo880_multiling/geoquery/",
                 train_lang="en",
                 test_lang=None,
                 bert_tokenizer=None,
                 min_freq: int = 2,
                 cvfolds=None,
                 testfold=None,
                 **kw):
        super(GeoDataset, self).__init__(**kw)
        self.train_lang = train_lang
        self.test_lang = test_lang if test_lang is not None else train_lang
        self.cvfolds, self.testfold = cvfolds, testfold
        self._initialize(p, bert_tokenizer, min_freq)

    def _initialize(self, p, bert_tokenizer, min_freq: int):
        self.data = {}
        self.bert_vocab = Vocab()
        self.bert_vocab.set_dict(bert_tokenizer.vocab)
        self.sentence_encoder = SequenceEncoder(
            lambda x: bert_tokenizer.tokenize(f"[CLS] {x} [SEP]"),
            vocab=self.bert_vocab)
        trainlines = [
            x for x in ujson.load(
                open(os.path.join(p, f"geo-{self.train_lang}.json"), "r"))
        ]
        testlines = [
            x for x in ujson.load(
                open(os.path.join(p, f"geo-{self.test_lang}.json"), "r"))
        ]
        trainlines = [x for x in trainlines if x["split"] == "train"]
        testlines = [x for x in testlines if x["split"] == "test"]
        if self.cvfolds is None:
            splits = ["train"] * len(trainlines) + ["test"] * len(testlines)
        else:
            cvsplit_len = len(trainlines) / self.cvfolds
            splits = []
            for i in range(0, self.cvfolds):
                splits += [i] * round(cvsplit_len * (i + 1) - len(splits))
            random.shuffle(splits)
            splits = [
                "valid" if x == self.testfold else "train" for x in splits
            ]
            splits = splits + ["test"] * len(testlines)
        questions = [x["nl"] for x in trainlines]
        queries = [x["mrl"] for x in trainlines]
        xquestions = [x["nl"] for x in testlines]
        xqueries = [x["mrl"] for x in testlines]
        questions += xquestions
        queries += xqueries

        # initialize output vocabulary
        outvocab = Vocab()
        for token, bertid in self.bert_vocab.D.items():
            outvocab.add_token(token, seen=False)

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=bert_tokenizer),
                                             vocab=outvocab,
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        keeptokens = set(self.bert_vocab.D.keys())
        self.query_encoder.finalize_vocab(min_freq=min_freq,
                                          keep_tokens=keeptokens)

        token_specs = self.build_token_specs(queries)
        self.token_specs = token_specs

        self.build_data(questions, queries, splits)

    def build_token_specs(self, outputs: Iterable[str]):
        token_specs = dict()

        def walk_the_tree(t, _ts):
            l = t.label()
            if l not in _ts:
                _ts[l] = [np.infty, -np.infty]
            minc, maxc = _ts[l]
            _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
            for c in t:
                walk_the_tree(c, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        # token_specs["and"][1] = np.infty

        return token_specs

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for inp, out, split in zip(inputs, outputs, splits):
            # tokenize both input and output
            inp_tokens = self.sentence_encoder.convert(inp,
                                                       return_what="tokens")[0]
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            # get gold tree
            gold_tree = lisp_to_tree(" ".join(out_tokens[:-1]))
            assert (gold_tree is not None)
            # replace words in output that can't be copied from given input to UNK tokens
            unktoken = self.query_encoder.vocab.unktoken
            inp_tokens_ = set(inp_tokens)
            out_tokens = [
                out_token if out_token in inp_tokens_ or
                (out_token in self.query_encoder.vocab
                 and out_token not in self.query_encoder.vocab.rare_tokens)
                else unktoken for out_token in out_tokens
            ]
            # convert token sequences to ids
            inp_tensor = self.sentence_encoder.convert(inp_tokens,
                                                       return_what="tensor")[0]
            out_tensor = self.query_encoder.convert(out_tokens,
                                                    return_what="tensor")[0]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens],
                                     self.sentence_encoder.vocab,
                                     self.query_encoder.vocab,
                                     token_specs=self.token_specs)

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out

    def get_split(self, split: str):
        data = []
        for split_e in split.split("+"):
            data += self.data[split_e]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
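        # right-pad every state's gold and input tensors to the batch maximum
        # length, then merge the individual states into a single batched state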
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5, shuffle=None):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize,
                                             split=split,
                                             shuffle=shuffle)
            return ret
        else:
            # assert(split in self.data.keys())
            shuffle = shuffle if shuffle is not None else split in (
                "train", "train+valid")
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=shuffle,
                            collate_fn=type(self).collate_fn)
            return dl
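

# A hedged usage sketch (not part of the original example): it assumes the
# GeoDataset constructor, whose first parameters are cut off above, accepts
# p, bert_tokenizer, min_freq, train_lang, test_lang, cvfolds and testfold as
# keyword arguments (all of them are used in the constructor body). The path
# and language code below are placeholders.
def _example_geodataset_usage():
    from transformers import BertTokenizer
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    gds = GeoDataset(p="../../datasets/geo/",
                     bert_tokenizer=bert_tokenizer,
                     min_freq=2,
                     train_lang="en")
    dls = gds.dataloader(batsize=10)  # dict with one DataLoader per split
    train_dl = gds.dataloader(split="train", batsize=10, shuffle=True)
    return dls, train_dl
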
def load_ds(traindomains=("restaurants", ),
            testdomain="housing",
            min_freq=1,
            mincoverage=1,
            top_k=np.infty,
            nl_mode="bert-base-uncased",
            fullsimplify=False,
            add_domain_start=True,
            useall=False):
    def tokenize_and_add_start(t, _domain):
        tokens = tree_to_lisp_tokens(t)
        starttok = f"@START/{_domain}@" if add_domain_start else "@START@"
        tokens = [starttok] + tokens
        return tokens

    allex = []
    for traindomain in traindomains:
        ds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full", simplify_blocks=True, restore_reverse=DATA_RESTORE_REVERSE, validfrac=.10)\
            .load(domain=traindomain)
        allex += ds[(None, None, lambda x: x in ("train", "valid"))]\
            .map(lambda x: (x[0], x[1], x[2], traindomain))\
            .examples  # don't use test examples

    testds = OvernightDatasetLoader(simplify_mode="light" if not fullsimplify else "full", simplify_blocks=True, restore_reverse=DATA_RESTORE_REVERSE)\
        .load(domain=testdomain)
    if useall:
        print("using all training examples")
        sortedexamples = testds[(None, None, "train")].examples
    else:
        sortedexamples = get_maximum_spanning_examples(
            testds[(None, None, "train")].examples,
            mincoverage=mincoverage,
            loadedex=[e for e in allex if e[2] == "train"])

    allex += testds[(
        None, None,
        "valid")].map(lambda x: (x[0], x[1], "ftvalid", testdomain)).examples
    allex += testds[(
        None, None,
        "test")].map(lambda x: (x[0], x[1], x[2], testdomain)).examples
    allex += [(ex[0], ex[1], "fttrain", testdomain) for ex in sortedexamples]

    _ds = Dataset(allex)
    ds = _ds.map(lambda x:
                 (x[0], tokenize_and_add_start(x[1], x[3]), x[2], x[3]))

    et = get_lf_abstract_transform(ds[lambda x: x[3] != testdomain].examples)
    ds = ds.map(lambda x: (x[0], et(x[1]), x[1], x[2], x[3]))

    seqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    absseqenc_vocab = Vocab(padid=0, startid=1, endid=2, unkid=UNKID)
    absseqenc = SequenceEncoder(vocab=absseqenc_vocab,
                                tokenizer=lambda x: x,
                                add_start_token=False,
                                add_end_token=True)
    fullseqenc = SequenceEncoder(vocab=seqenc_vocab,
                                 tokenizer=lambda x: x,
                                 add_start_token=False,
                                 add_end_token=True)
    for example in ds.examples:
        absseqenc.inc_build_vocab(example[1],
                                  seen=example[3] in ("train", "fttrain"))
        fullseqenc.inc_build_vocab(example[2],
                                   seen=example[3] in ("train", "fttrain"))
    absseqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)
    fullseqenc.finalize_vocab(min_freq=min_freq, top_k=top_k)

    nl_tokenizer = AutoTokenizer.from_pretrained(nl_mode)

    def tokenize(x):
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0],
               absseqenc.convert(x[1], return_what="tensor"),
               fullseqenc.convert(x[2], return_what="tensor"), x[3], x[0],
               x[1], x[4])
        return ret
    tds, ftds, vds, fvds, xds = ds[(None, None, None, "train", None)].map(tokenize), \
                          ds[(None, None, None, "fttrain", None)].map(tokenize), \
                          ds[(None, None, None, "valid", None)].map(tokenize), \
                          ds[(None, None, None, "ftvalid", None)].map(tokenize), \
                          ds[(None, None, None, "test", None)].map(tokenize)
    return tds, ftds, vds, fvds, xds, nl_tokenizer, fullseqenc, absseqenc
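
# A hedged usage sketch (not part of the original example): the domain names
# are placeholders for Overnight domains, and it assumes the mapped datasets
# returned above expose their examples through an .examples attribute, as the
# raw Dataset objects do elsewhere in these examples.
def _example_multidomain_load_ds():
    tds, ftds, vds, fvds, xds, nl_tokenizer, fullseqenc, absseqenc = \
        load_ds(traindomains=("restaurants", "publications"),
                testdomain="housing",
                min_freq=1,
                mincoverage=2)
    # tds/vds: source-domain train/valid, ftds/fvds: target-domain finetuning
    # train/valid, xds: target-domain test
    print(len(tds.examples), len(ftds.examples), len(xds.examples))
    print(absseqenc.vocab.number_of_ids(), fullseqenc.vocab.number_of_ids())
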
Example #22
0
    def test_beam_search_stored_probs(self):
        with torch.no_grad():
            texts = ["a b"] * 2
            from parseq.vocab import SequenceEncoder
            se = SequenceEncoder(tokenizer=lambda x: x.split())
            for t in texts:
                se.inc_build_vocab(t)
            se.finalize_vocab()
            x = BasicDecoderState(texts, texts, se, se)
            x.start_decoding()

            class Model(TransitionModel):
                transition_tensor = torch.tensor([[0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .51, .48],
                                                  [0, 0, 0, 0.01, .01, .98]])

                def forward(self, x: BasicDecoderState):
                    prev = x.prev_actions
                    outprobs = self.transition_tensor[prev]
                    outprobs = torch.log(outprobs)
                    outprobs -= 0.01 * torch.rand_like(outprobs)
                    return outprobs, x

            model = Model()

            beamsize = 3
            maxtime = 10
            beam_xs = [
                x.make_copy(detach=False, deep=True) for _ in range(beamsize)
            ]
            beam_states = BeamState(beam_xs)

            print(len(beam_xs))
            print(len(beam_states))

            bt = BeamTransition(model, beamsize, maxtime=maxtime)
            i = 0
            _, _, y, _ = bt(x, i)
            i += 1
            _, _, y, _ = bt(y, i)

            all_terminated = False
            while not all_terminated:
                start_time = time.time()
                _, _, y, all_terminated = bt(y, i)
                i += 1
                # print(i)
                end_time = time.time()
                print(f"{i}: {end_time - start_time}")

            print(y)
            print(y.bstates.get(0).followed_actions)

            best_actions = y.bstates.get(0).followed_actions
            best_actionprobs = y.actionprobs.get(0)
            for i in range(len(best_actions)):
                print(i)
                i_prob = 0
                for j in range(len(best_actions[i])):
                    action_id = best_actions[i, j]
                    action_prob = best_actionprobs[i, j, action_id]
                    i_prob += action_prob
                print(i_prob)
                print(y.bscores[i, 0])
                self.assertTrue(torch.allclose(i_prob, y.bscores[i, 0]))
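
            # A hedged, vectorized form of the same check (not in the original
            # example); it assumes followed_actions is an integer tensor of shape
            # (batch, time) and actionprobs.get(0) has shape (batch, time, vocab):
            gathered = best_actionprobs.gather(
                -1, best_actions.long()[:, :, None]).squeeze(-1)
            self.assertTrue(torch.allclose(gathered.sum(-1), y.bscores[:, 0]))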
Example #23
0
def load_ds(domain="restaurants",
            nl_mode="bert-base-uncased",
            trainonvalid=False,
            noreorder=False,
            numbered=False):
    """
    Creates a dataset of examples which have
    * NL question and tensor
    * original FL tree
    * reduced FL tree with slots (this is randomly generated)
    * tensor corresponding to reduced FL tree with slots
    * mask specifying which elements in reduced FL tree are terminated
    * 2D gold that specifies whether a token/action is in gold for every position (compatibility with MML!)
    """
    # orderless = {"op:and", "SW:concat"}     # only use in eval!!
    orderless = ORDERLESS

    ds = OvernightDatasetLoader(simplify_mode="none").load(
        domain=domain, trainonvalid=trainonvalid)
    # ds contains 3-tuples of (input, output tree, split name)

    if not noreorder:
        ds = ds.map(lambda x:
                    (x[0], reorder_tree(x[1], orderless=orderless), x[2]))
    ds = ds.map(lambda x: (x[0], tree_to_seq(x[1]), x[2]))

    if numbered:
        ds = ds.map(lambda x: (x[0], make_numbered_tokens(x[1]), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    vocab.add_token("@BOS@", seen=np.infty)
    vocab.add_token("@EOS@", seen=np.infty)
    vocab.add_token("@STOP@", seen=np.infty)

    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)

    tds, vds, xds = ds[lambda x: x[2] == "train"], \
                    ds[lambda x: x[2] == "valid"], \
                    ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(vocab=vocab,
                             tokenizer=lambda x: x,
                             add_start_token=False,
                             add_end_token=False)
    for example in tds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=True)
    for example in vds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    for example in xds.examples:
        query = example[1]
        seqenc.inc_build_vocab(query, seen=False)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        seq = seqenc.convert(x[1], return_what="tensor")
        ret = (nl_tokenizer.encode(x[0], return_tensors="pt")[0], seq)
        return ret

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless
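
# A hedged usage sketch (not part of the original example): it assumes the
# mapped datasets returned above keep their examples in an .examples attribute,
# as the raw Dataset objects do elsewhere in these examples.
def _example_numbered_load_ds():
    tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless = \
        load_ds(domain="restaurants", nl_mode="bert-base-uncased", numbered=True)
    print(len(tds_seq.examples), len(vds_seq.examples), len(xds_seq.examples))
    print(seqenc.vocab.number_of_ids())  # output vocabulary size
    print(orderless)  # operators treated as orderless when reordering
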
Example #24
0
class GeoQueryDataset(object):
    def __init__(self,
                 p="../../datasets/geoquery/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 **kw):
        super(GeoQueryDataset, self).__init__(**kw)
        self.data = {}
        self.sentence_encoder = sentence_encoder
        questions = [
            x.strip()
            for x in open(os.path.join(p, "questions.txt"), "r").readlines()
        ]
        queries = [
            x.strip()
            for x in open(os.path.join(p, "queries.funql"), "r").readlines()
        ]
        trainidxs = set([
            int(x.strip()) for x in open(os.path.join(p, "train_indexes.txt"),
                                         "r").readlines()
        ])
        testidxs = set([
            int(x.strip()) for x in open(os.path.join(p, "test_indexes.txt"),
                                         "r").readlines()
        ])
        splits = [None] * len(questions)
        for trainidx in trainidxs:
            splits[trainidx] = "train"
        for testidx in testidxs:
            splits[testidx] = "test"
        if any(split is None for split in splits):
            print(
                f"{len([split for split in splits if split is None])} examples not assigned to any split"
            )

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str]):
        for inp, out, split in zip(inputs, outputs, splits):
            state = BasicDecoderState([inp], [out], self.sentence_encoder,
                                      self.query_encoder)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert (split in self.data.keys())
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=GeoQueryDataset.collate_fn)
            return dl
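
# A hedged usage sketch (not part of the original example): a word-level
# SequenceEncoder is assumed for the questions, as in the tests above, and the
# dataset files are expected under the constructor's default geoquery path.
def _example_geoquery_dataset():
    from parseq.vocab import SequenceEncoder
    sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
    gqd = GeoQueryDataset(sentence_encoder=sentence_encoder, min_freq=2)
    dls = gqd.dataloader(batsize=10)  # {"train": DataLoader, "test": DataLoader}
    batch = next(iter(dls["train"]))  # a merged BasicDecoderState of up to 10 examples
    return batch
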
Example #25
0
class OvernightDataset(object):
    def __init__(self,
                 p="../../datasets/overnightData/",
                 pcache="../../datasets/overnightCache/",
                 domain: str = "restaurants",
                 sentence_encoder: SequenceEncoder = None,
                 usecache=True,
                 min_freq: int = 2,
                 **kw):
        super(OvernightDataset, self).__init__(**kw)
        self._simplify_filters = True  # if True, filter expressions are converted to orderless and-expressions
        self._pcache = pcache if usecache else None
        self._domain = domain
        self._usecache = usecache
        self._initialize(p, domain, sentence_encoder, min_freq)

    def lines_to_examples(self, lines: List[str]):
        maxsize_before = 0
        avgsize_before = []
        maxsize_after = 0
        avgsize_after = []
        afterstring = set()

        def simplify_tree(t: Tree):
            if t.label() == "call":
                assert (len(t[0]) == 0)
                # if not t[0].label().startswith("SW."):
                #     print(t)
                # assert(t[0].label().startswith("SW."))
                t.set_label(t[0].label())
                del t[0]
            elif t.label() == "string":
                afterstring.update(set([tc.label() for tc in t]))
                assert (len(t) == 1)
                assert (len(t[0]) == 0)
                t.set_label(f"arg:{t[0].label()}")
                del t[0]
            if t.label().startswith(
                    "edu.stanford.nlp.sempre.overnight.SimpleWorld."):
                prefixlen = len("edu.stanford.nlp.sempre.overnight.SimpleWorld.")
                t.set_label("SW:" + t.label()[prefixlen:])
            if t.label() == "SW:getProperty":
                assert (len(t) == 2)
                ret = simplify_tree(t[1])
                ret.append(simplify_tree(t[0]))
                return ret
            elif t.label() == "SW:singleton":
                assert (len(t) == 1)
                assert (len(t[0]) == 0)
                return t[0]
            elif t.label() == "SW:ensureNumericProperty":
                assert (len(t) == 1)
                return simplify_tree(t[0])
            elif t.label() == "SW:ensureNumericEntity":
                assert (len(t) == 1)
                return simplify_tree(t[0])
            elif t.label() == "SW:aggregate":
                assert (len(t) == 2)
                ret = simplify_tree(t[0])
                assert (ret.label() in ["arg:avg", "arg:sum"])
                assert (len(ret) == 0)
                ret.set_label(f"agg:{ret.label()}")
                ret.append(simplify_tree(t[1]))
                return ret
            else:
                t[:] = [simplify_tree(tc) for tc in t]
                return t

        def simplify_further(t):
            """ simplifies filters and count expressions """
            # replace filters with ands
            if t.label() == "SW:filter" and self._simplify_filters is True:
                if len(t) not in (2, 4):
                    raise Exception(
                        f"filter expression should have 2 or 4 children, got {len(children)}"
                    )
                children = [simplify_further(tc) for tc in t]
                startset = children[0]
                if len(children) == 2:
                    condition = Tree("cond:has", [children[1]])
                elif len(children) == 4:
                    condition = Tree(f"cond:{children[2].label()}",
                                     [children[1], children[3]])
                conditions = [condition]
                if startset.label() == "op:and":
                    conditions = startset[:] + conditions
                else:
                    conditions = [startset] + conditions
                # check for same conditions:
                i = 0
                while i < len(conditions) - 1:
                    j = i + 1
                    while j < len(conditions):
                        if conditions[i] == conditions[j]:
                            print(f"SAME!: {conditions[i]}, {conditions[j]}")
                            del conditions[j]
                            j -= 1
                        j += 1
                    i += 1

                ret = Tree(f"op:and", conditions)
                return ret
            # replace countSuperlatives with specific ones
            elif t.label() == "SW:countSuperlative":
                assert (t[1].label() in ["arg:max", "arg:min"])
                t.set_label(f"SW:CNT-{t[1].label()}")
                del t[1]
                t[:] = [simplify_further(tc) for tc in t]
            elif t.label() == "SW:countComparative":
                assert (t[2].label() in [
                    "arg:<", "arg:<=", "arg:>", "arg:>=", "arg:=", "arg:!="
                ])
                t.set_label(f"SW:CNT-{t[2].label()}")
                del t[2]
                t[:] = [simplify_further(tc) for tc in t]
            else:
                t[:] = [simplify_further(tc) for tc in t]
            return t

        def simplify_furthermore(t):
            """ replace reverse rels"""
            if t.label() == "arg:!type":
                t.set_label("arg:~type")
                return t
            elif t.label() == "SW:reverse":
                assert (len(t) == 1)
                assert (t[0].label().startswith("arg:"))
                assert (len(t[0]) == 0)
                t.set_label(f"arg:~{t[0].label()[4:]}")
                del t[0]
                return t
            elif t.label().startswith("cond:arg:"):
                assert (len(t) == 2)
                head = t[0]
                head = simplify_furthermore(head)
                assert (head.label().startswith("arg:"))
                assert (len(head) == 0)
                headlabel = f"arg:~{head.label()[4:]}"
                headlabel = headlabel.replace("~~", "")
                head.set_label(headlabel)
                body = simplify_furthermore(t[1])
                if t.label()[len("cond:arg:"):] != "=":
                    body = Tree(t.label()[5:], [body])
                head.append(body)
                return head
            else:
                t[:] = [simplify_furthermore(tc) for tc in t]
                return t

        def simplify_final(t):
            assert (t.label() == "SW:listValue")
            assert (len(t) == 1)
            return t[0]

        ret = []
        ltp = None
        j = 0
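        # incrementally parse the (possibly multi-line) lisp examples; lisp_to_pas
        # returns a complete parse once a full expression has been consumed, after
        # which the tree is simplified in four passes into a (question, logical form) pair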
        for i, line in enumerate(lines):
            z, ltp = lisp_to_pas(line, ltp)
            if z is not None:
                print(f"Example {j}:")
                ztree = pas_to_tree(z[1][2][1][0])
                maxsize_before = max(maxsize_before, tree_size(ztree))
                avgsize_before.append(tree_size(ztree))
                lf = simplify_tree(ztree)
                lf = simplify_further(lf)
                lf = simplify_furthermore(lf)
                lf = simplify_final(lf)
                question = z[1][0][1][0]
                assert (question[0] == '"' and question[-1] == '"')
                ret.append((question[1:-1], lf))
                print(ret[-1][0])
                print(ret[-1][1])
                ltp = None
                maxsize_after = max(maxsize_after, tree_size(lf))
                avgsize_after.append(tree_size(lf))

                print(pas_to_tree(z[1][2][1][0]))
                print()
                j += 1

        avgsize_before = sum(avgsize_before) / len(avgsize_before)
        avgsize_after = sum(avgsize_after) / len(avgsize_after)

        print("Simplification results ({j} examples):")
        print(f"\t Max, Avg size before: {maxsize_before}, {avgsize_before}")
        print(f"\t Max, Avg size after: {maxsize_after}, {avgsize_after}")

        return ret

    def _load_cached(self):
        train_cached = ujson.load(
            open(os.path.join(self._pcache, f"{self._domain}.train.json"),
                 "r"))
        trainexamples = [(x, Tree.fromstring(y)) for x, y in train_cached]
        test_cached = ujson.load(
            open(os.path.join(self._pcache, f"{self._domain}.test.json"), "r"))
        testexamples = [(x, Tree.fromstring(y)) for x, y in test_cached]
        print("loaded from cache")
        return trainexamples, testexamples

    def _cache(self, trainexamples: List[Tuple[str, Tree]],
               testexamples: List[Tuple[str, Tree]]):
        train_cached, test_cached = None, None
        if os.path.exists(
                os.path.join(self._pcache, f"{self._domain}.train.json")):
            try:
                train_cached = ujson.load(
                    open(
                        os.path.join(self._pcache,
                                     f"{self._domain}.train.json"), "r"))
                test_cached = ujson.load(
                    open(
                        os.path.join(self._pcache,
                                     f"{self._domain}.test.json"), "r"))
            except (IOError, ValueError) as e:
                pass
        trainexamples = [(x, str(y)) for x, y in trainexamples]
        testexamples = [(x, str(y)) for x, y in testexamples]

        if train_cached != trainexamples:
            with open(os.path.join(self._pcache, f"{self._domain}.train.json"),
                      "w") as f:
                ujson.dump(trainexamples, f, indent=4, sort_keys=True)
        if test_cached != testexamples:
            with open(os.path.join(self._pcache, f"{self._domain}.test.json"),
                      "w") as f:
                ujson.dump(testexamples, f, indent=4, sort_keys=True)
        print("saved in cache")

    def _initialize(self, p, domain, sentence_encoder: SequenceEncoder,
                    min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder

        trainexamples, testexamples = None, None
        if self._usecache:
            try:
                trainexamples, testexamples = self._load_cached()
            except (IOError, ValueError) as e:
                pass

        if trainexamples is None:

            trainlines = [
                x.strip() for x in open(
                    os.path.join(p, f"{domain}.paraphrases.train.examples"),
                    "r").readlines()
            ]
            testlines = [
                x.strip() for x in open(
                    os.path.join(p, f"{domain}.paraphrases.test.examples"),
                    "r").readlines()
            ]

            trainexamples = self.lines_to_examples(trainlines)
            testexamples = self.lines_to_examples(testlines)

            if self._usecache:
                self._cache(trainexamples, testexamples)

        questions, queries = tuple(zip(*(trainexamples + testexamples)))
        trainlen = int(round(0.8 * len(trainexamples)))
        validlen = len(trainexamples) - trainlen
        splits = ["train"] * trainlen + ["valid"] * validlen
        # random.seed(1223)
        random.shuffle(splits)
        assert (len(splits) == len(trainexamples))
        splits = splits + ["test"] * len(testexamples)

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            tree_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            self.sentence_encoder.inc_build_vocab(question,
                                                  seen=split == "train")
            self.query_encoder.inc_build_vocab(query, seen=split == "train")
        for word, wordid in self.sentence_encoder.vocab.D.items():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq)
        self.query_encoder.finalize_vocab(min_freq=min_freq)

        self.build_data(questions, queries, splits)

    def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
                   splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        eid = 0
        for inp, out, split in zip(inputs, outputs, splits):
            state = TreeDecoderState([inp], [out], self.sentence_encoder,
                                     self.query_encoder)
            state.eids = np.asarray([eid], dtype="int64")
            maxlen_in = max(maxlen_in, len(state.inp_tokens[0]))
            maxlen_out = max(maxlen_out, len(state.gold_tokens[0]))
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            eid += 1
        self.maxlen_input, self.maxlen_output = maxlen_in, maxlen_out

    def get_split(self, split: str):
        splits = split.split("+")
        data = []
        for split in splits:
            data += self.data[split]
        return DatasetSplitProxy(data)

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split in ("train", "train+valid"),
                            collate_fn=OvernightDataset.collate_fn)
            return dl
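
# A hedged usage sketch (not part of the original example): a word-level
# SequenceEncoder is assumed for the questions, and the data/cache paths are
# the constructor defaults.
def _example_overnight_dataset():
    from parseq.vocab import SequenceEncoder
    sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
    ods = OvernightDataset(domain="restaurants",
                           sentence_encoder=sentence_encoder,
                           usecache=True,
                           min_freq=2)
    print(ods.maxlen_input, ods.maxlen_output)
    train_dl = ods.dataloader("train+valid", batsize=10)  # shuffled, both splits merged
    return train_dl
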
Example #26
0
class GeoQueryDatasetFunQL(object):
    def __init__(self,
                 p="../../datasets/geoquery/",
                 sentence_encoder: SequenceEncoder = None,
                 min_freq: int = 2,
                 **kw):
        super(GeoQueryDatasetFunQL, self).__init__(**kw)
        self._initialize(p, sentence_encoder, min_freq)

    def _initialize(self, p, sentence_encoder: SequenceEncoder, min_freq: int):
        self.data = {}
        self.sentence_encoder = sentence_encoder
        questions = [
            x.strip()
            for x in open(os.path.join(p, "questions.txt"), "r").readlines()
        ]
        queries = [
            x.strip()
            for x in open(os.path.join(p, "queries.funql"), "r").readlines()
        ]
        trainidxs = set([
            int(x.strip()) for x in open(os.path.join(p, "train_indexes.txt"),
                                         "r").readlines()
        ])
        testidxs = set([
            int(x.strip()) for x in open(os.path.join(p, "test_indexes.txt"),
                                         "r").readlines()
        ])
        splits = [None] * len(questions)
        for trainidx in trainidxs:
            splits[trainidx] = "train"
        for testidx in testidxs:
            splits[testidx] = "test"
        if any(split is None for split in splits):
            print(
                f"{len([split for split in splits if split is None])} examples not assigned to any split"
            )

        self.query_encoder = SequenceEncoder(tokenizer=partial(
            basic_query_tokenizer, strtok=sentence_encoder.tokenizer),
                                             add_end_token=True)

        # build vocabularies
        unktokens = set()
        for i, (question, query,
                split) in enumerate(zip(questions, queries, splits)):
            question_tokens = self.sentence_encoder.inc_build_vocab(
                question, seen=split == "train")
            query_tokens = self.query_encoder.inc_build_vocab(
                query, seen=split == "train")
            unktokens |= set(query_tokens) - set(question_tokens)
        for word in self.sentence_encoder.vocab.counts.keys():
            self.query_encoder.vocab.add_token(word, seen=False)
        self.sentence_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        self.query_encoder.finalize_vocab(min_freq=min_freq, keep_rare=True)
        unktokens = unktokens & self.query_encoder.vocab.rare_tokens

        self.build_data(questions, queries, splits, unktokens=unktokens)

    def build_data(self,
                   inputs: Iterable[str],
                   outputs: Iterable[str],
                   splits: Iterable[str],
                   unktokens: Set[str] = None):
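        # gold_map remaps the ids of rare output tokens (unktokens) to the UNK id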
        gold_map = None
        if unktokens is not None:
            gold_map = torch.arange(
                0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = prolog_to_tree(out)
            assert (gold_tree is not None)
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree], inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens], self.sentence_encoder.vocab,
                                     self.query_encoder.vocab)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)

    def get_split(self, split: str):
        return DatasetSplitProxy(self.data[split])

    @staticmethod
    def collate_fn(data: Iterable):
        goldmaxlen = 0
        inpmaxlen = 0
        data = [state.make_copy(detach=True, deep=True) for state in data]
        for state in data:
            goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
            inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        for state in data:
            state.gold_tensor = torch.cat([
                state.gold_tensor,
                state.gold_tensor.new_zeros(
                    1, goldmaxlen - state.gold_tensor.size(1))
            ], 1)
            state.inp_tensor = torch.cat([
                state.inp_tensor,
                state.inp_tensor.new_zeros(
                    1, inpmaxlen - state.inp_tensor.size(1))
            ], 1)
        ret = data[0].merge(data)
        return ret

    def dataloader(self, split: str = None, batsize: int = 5):
        if split is None:  # return all splits
            ret = {}
            for split in self.data.keys():
                ret[split] = self.dataloader(batsize=batsize, split=split)
            return ret
        else:
            assert (split in self.data.keys())
            dl = DataLoader(self.get_split(split),
                            batch_size=batsize,
                            shuffle=split == "train",
                            collate_fn=GeoQueryDatasetFunQL.collate_fn)
            return dl
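
# A hedged usage sketch (not part of the original example), mirroring the
# GeoQueryDataset usage above: a word-level SequenceEncoder is assumed for the
# questions, and the dataset files are expected under the default geoquery path.
def _example_geoquery_funql_dataset():
    from parseq.vocab import SequenceEncoder
    sentence_encoder = SequenceEncoder(tokenizer=lambda x: x.split())
    gqd = GeoQueryDatasetFunQL(sentence_encoder=sentence_encoder, min_freq=2)
    batch = next(iter(gqd.dataloader("train", batsize=10)))  # merged TreeDecoderState
    return batch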