Example #1
    def build_data(self, inputs: Iterable[str], outputs: Iterable[str], splits: Iterable[str], unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        maxlins = 0
        numlins_counts = [0] * (self.max_lins_allowed + 1)
        if unktokens is not None:
            # identity lookup table over the output vocab; rare-token ids
            # are remapped to the @UNK@ id below
            gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert gold_tree is not None
            out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")

            if split == "train":
                gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
                numlins = 0
                for gold_tree_reordered in get_tree_permutations(gold_tree_, orderless={"and", "or"}):
                    if numlins >= self.max_lins_allowed:
                        break
                    out_ = tree_to_lisp(gold_tree_reordered)
                    out_tensor_, out_tokens_ = self.query_encoder.convert(out_, return_what="tensor,tokens")
                    if gold_map is not None:
                        out_tensor_ = gold_map[out_tensor_]  # remap rare tokens to @UNK@

                    state = TreeDecoderState([inp], [gold_tree_reordered],
                                             inp_tensor[None, :], out_tensor_[None, :],
                                             [inp_tokens], [out_tokens_],
                                             self.sentence_encoder.vocab, self.query_encoder.vocab,
                                             token_specs=self.token_specs)
                    if split not in self.data:
                        self.data[split] = []
                    self.data[split].append(state)
                    numlins += 1
                numlins_counts[numlins] += 1
                maxlins = max(maxlins, numlins)
            else:
                if gold_map is not None:
                    out_tensor = gold_map[out_tensor]

                state = TreeDecoderState([inp], [gold_tree],
                                         inp_tensor[None, :], out_tensor[None, :],
                                         [inp_tokens], [out_tokens],
                                         self.sentence_encoder.vocab, self.query_encoder.vocab,
                                         token_specs=self.token_specs)
                if split not in self.data:
                    self.data[split] = []
                self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out
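
Note: gold_map above is just an index tensor used as a lookup table; indexing an id sequence with it rewrites every rare-token id to the @UNK@ id in one vectorized step. A minimal sketch of the idea with a hypothetical 5-id vocabulary (the ids and the unk position here are illustrative, not taken from the codebase):

import torch

vocab_size, unk_id = 5, 1           # assume id 1 is @UNK@
rare_ids = [3, 4]                   # assume ids 3 and 4 are rare tokens
gold_map = torch.arange(0, vocab_size)
for rid in rare_ids:
    gold_map[rid] = unk_id          # gold_map is now [0, 1, 2, 1, 1]

seq = torch.tensor([0, 3, 2, 4])    # an encoded output sequence
print(gold_map[seq])                # tensor([0, 1, 2, 1]) -- rare ids mapped to @UNK@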
Example #2
    def build_data(self,
                   inputs: Iterable[str],
                   outputs: Iterable[str],
                   splits: Iterable[str],
                   unktokens: Set[str] = None):
        gold_map = None
        maxlen_in, maxlen_out = 0, 0
        if unktokens is not None:
            gold_map = torch.arange(0,
                                    self.query_encoder.vocab.number_of_ids())
            for rare_token in unktokens:
                gold_map[self.query_encoder.vocab[rare_token]] = \
                    self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
        for inp, out, split in zip(inputs, outputs, splits):

            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(out)
            assert gold_tree is not None
            out_tensor, out_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]

            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :],
                                     out_tensor[None, :], [inp_tokens],
                                     [out_tokens],
                                     self.sentence_encoder.vocab,
                                     self.query_encoder.vocab,
                                     token_specs=self.token_specs)
            if split == "train" and self.reorder_random is True:
                gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
                random_gold_tree = random.choice(
                    get_tree_permutations(gold_tree_, orderless={"and"}))
                out_ = tree_to_lisp(random_gold_tree)
                out_tensor_, out_tokens_ = self.query_encoder.convert(
                    out_, return_what="tensor,tokens")
                if gold_map is not None:
                    out_tensor_ = gold_map[out_tensor_]
                state.gold_tensor = out_tensor_[None]

            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, len(out_tensor))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out
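
Unlike Example #1, which keeps up to max_lins_allowed permuted copies of each training example, this variant stores one state per example and only swaps in a single randomly chosen reordering of the gold tree. get_tree_permutations is project-specific; the sketch below shows what such a helper plausibly computes (children of labels in `orderless` may appear in any order), assuming a simple (label, children) tuple representation of trees:

from itertools import permutations, product

def tree_permutations(tree, orderless=frozenset({"and", "or"})):
    # tree is assumed to be (label, [child, ...]); leaves have empty child lists
    label, children = tree
    child_variants = [tree_permutations(c, orderless) for c in children]
    results = []
    for combo in product(*child_variants):
        # permute the children only under an orderless label
        orders = permutations(combo) if label in orderless else [combo]
        results.extend((label, list(o)) for o in orders)
    return results

tree = ("and", [("a", []), ("b", [])])
print(len(tree_permutations(tree)))  # 2 -- both child orders of "and"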
Example #3
    def test_forward(self,
                     x: torch.Tensor,
                     gold: torch.Tensor = None
                     ):  # run the decoder end-to-end and compute evaluation metrics
        preds, prednll, maxmaxnll, stepsused = self.get_prediction(x)

        def tensor_to_trees(x, vocab: Vocab):
            xstrs = [vocab.tostr(x[i]).replace("@START@", "") for i in range(len(x))]
            xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
            trees = []
            for xstr in xstrs:
                # drop everything after @END@, if present
                xstr = xstr.split("@END@")[0]
                # add an opening parenthesis if not there
                xstr = xstr.strip()
                if len(xstr) == 0 or xstr[0] != "(":
                    xstr = "(" + xstr
                # balance parentheses
                paren_imbalance = xstr.count("(") - xstr.count(")")
                xstr = xstr + ")" * max(0, paren_imbalance)  # append missing closing parentheses
                xstr = "(" * -min(0, paren_imbalance) + xstr  # prepend missing opening parentheses
                try:
                    tree = taglisp_to_tree(xstr)
                    if isinstance(tree, tuple) and len(tree) == 2 and tree[0] is None:
                        # a (None, ...) tuple is treated as a failed parse
                        tree = None
                except Exception:
                    tree = None
                trees.append(tree)
            return trees

        # compute loss and metrics
        gold_trees = tensor_to_trees(gold, vocab=self.vocab)
        pred_trees = tensor_to_trees(preds, vocab=self.vocab)
        treeaccs = [
            float(
                are_equal_trees(gold_tree,
                                pred_tree,
                                orderless=ORDERLESS,
                                unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(gold_trees, pred_trees)
        ]
        ret = {
            "treeacc": torch.tensor(treeaccs).to(x.device),
            "stepsused": stepsused
        }

        # compute bleu scores
        bleus = []
        lcsf1s = []
        for gold_tree, pred_tree in zip(gold_trees, pred_trees):
            if pred_tree is None or gold_tree is None:
                bleuscore = 0
                lcsf1 = 0
            else:
                gold_str = tree_to_lisp(gold_tree)
                pred_str = tree_to_lisp(pred_tree)
                bleuscore = sentence_bleu([gold_str.split(" ")],
                                          pred_str.split(" "))
                lcsn = lcs(gold_str, pred_str)
                lcsrec = lcsn / len(gold_str)
                lcsprec = lcsn / len(pred_str)
                # avoid division by zero when there is no common subsequence
                lcsf1 = 2 * lcsrec * lcsprec / (lcsrec + lcsprec) if lcsn > 0 else 0.
            bleus.append(bleuscore)
            lcsf1s.append(lcsf1)
        bleus = torch.tensor(bleus).to(x.device)
        ret["bleu"] = bleus
        ret["lcsf1"] = torch.tensor(lcsf1s).to(x.device)

        d, logits = self.train_forward(x, gold)
        nll, acc, elemacc = d["loss"], d["acc"], d["elemacc"]
        ret["nll"] = nll
        ret["acc"] = acc
        ret["elemacc"] = elemacc

        # d, logits = self.train_forward(x, preds[:, 1:])
        # decnll = d["loss"]
        # ret["decnll"] = decnll

        ret["decnll"] = prednll
        ret["maxmaxnll"] = maxmaxnll
        return ret, pred_trees
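
The parenthesis-balancing step inside tensor_to_trees is worth isolating: the decoder can emit malformed sequences, and counting the imbalance lets the repair append or prepend exactly the missing parentheses. The same logic, extracted as a self-contained sketch:

def balance_parens(xstr: str) -> str:
    imbalance = xstr.count("(") - xstr.count(")")
    xstr = xstr + ")" * max(0, imbalance)   # close any unclosed "("
    xstr = "(" * -min(0, imbalance) + xstr  # open any unmatched ")"
    return xstr

assert balance_parens("(and (a b") == "(and (a b))"
assert balance_parens("a b))") == "((a b))"

The repair only fixes counts, not placement, so taglisp_to_tree can still fail; the surrounding try/except maps those cases to None, and the BLEU/LCS-F1 loop explicitly scores them as zero.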
Example #4
    def tree_to_lisptokens(self, x: Tree):
        xstr = tree_to_lisp(x)
        # pad parentheses with spaces so they split off as separate tokens
        xstr = xstr.replace("(", " ( ").replace(")", " ) ")
        xstr = re.sub(r"\s+", " ", xstr).strip()
        return xstr.split(" ")
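
A quick usage sketch of the same tokenization applied to a raw lisp string (tree_to_lisp and Tree are project-specific, so the tree-to-string step is bypassed here):

import re

def lisp_tokens(xstr: str):
    xstr = xstr.replace("(", " ( ").replace(")", " ) ")
    xstr = re.sub(r"\s+", " ", xstr).strip()
    return xstr.split(" ")

print(lisp_tokens("(and (foo bar) baz)"))
# ['(', 'and', '(', 'foo', 'bar', ')', 'baz', ')']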