Example #1
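Both nesting levels are permuted between a and b, yet the trees compare equal: the defaults evidently treat "and" and "or" as orderless operators.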
 def test_unordered_twolevel(self):
     a = lisp_to_tree(
         "(and  (or (wife BO) (spouse BO))  (or (wife BR) (spouse BR))  ) ")
     b = lisp_to_tree(
         "(and (or (spouse BR) (wife BR))  (or (spouse BO) (wife BO))  )")
     print(are_equal_trees(a, b))
     self.assertTrue(are_equal_trees(a, b))
Example #2
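The unknown token never matches anything, not even itself: a tree containing @UNK@ compares unequal even to an identical copy of itself.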
 def test_it_unk(self):
     a = lisp_to_tree("(and (wife BO) (spouse BO))")
     b = lisp_to_tree("(and (wife BO) (spouse @UNK@))")
     print(are_equal_trees(a, b))
     print(are_equal_trees(b, b))
     self.assertFalse(are_equal_trees(a, b))
     self.assertFalse(are_equal_trees(b, b))
Example #3
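A test-time forward pass: predicted token tensors are decoded back to strings, repaired (markers stripped, parentheses balanced), parsed with lisp_to_tree, and scored for exact tree accuracy against gold.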
    def test_forward(self,
                     x: torch.Tensor,
                     gold: torch.Tensor = None
                     ):  # --> implement how decoder operates end-to-end
        preds, stepsused = self.get_prediction(x)

        def tensor_to_trees(x, vocab: Vocab):
            xstrs = [
                vocab.tostr(x[i]).replace("@BOS@", "").replace("@EOS@", "")
                for i in range(len(x))
            ]
            xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
            trees = []
            for xstr in xstrs:
                # drop everything after @END@, if present
                xstr = xstr.split("@END@")[0]
                # add an opening parenthesis if it's missing
                xstr = xstr.strip()
                if len(xstr) == 0 or xstr[0] != "(":
                    xstr = "(" + xstr
                # balance parentheses
                paren_imbalance = xstr.count("(") - xstr.count(")")
                xstr = xstr + ")" * max(0, paren_imbalance)    # append missing ")"
                xstr = "(" * -min(0, paren_imbalance) + xstr   # prepend missing "("
                try:
                    tree = lisp_to_tree(xstr)
                    # a (None, error) tuple signals a failed parse
                    if isinstance(tree, tuple) and len(tree) == 2 and tree[0] is None:
                        tree = None
                except Exception:
                    tree = None
                trees.append(tree)
            return trees

        # compute loss and metrics
        gold_trees = tensor_to_trees(gold, vocab=self.vocab)
        pred_trees = tensor_to_trees(preds, vocab=self.vocab)
        treeaccs = [
            float(
                are_equal_trees(gold_tree,
                                pred_tree,
                                orderless=ORDERLESS,
                                unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(gold_trees, pred_trees)
        ]
        ret = {
            "treeacc": torch.tensor(treeaccs).to(x.device),
            "stepsused": stepsused
        }
        return ret, pred_trees
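The string-repair logic inside tensor_to_trees recurs in Examples #9 and #14. A minimal standalone sketch of just the parenthesis balancing (the helper name balance_parens is ours, not parseq's):

def balance_parens(xstr: str) -> str:
    """Repair a possibly truncated LISP string so it parses (hypothetical helper)."""
    xstr = xstr.split("@END@")[0].strip()       # drop everything after @END@
    if not xstr.startswith("("):
        xstr = "(" + xstr                       # ensure an opening parenthesis
    imbalance = xstr.count("(") - xstr.count(")")
    xstr = xstr + ")" * max(0, imbalance)       # append missing ")"
    xstr = "(" * -min(0, imbalance) + xstr      # prepend missing "("
    return xstr

print(balance_parens("(and (wife BO"))          # -> (and (wife BO))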
Example #4
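A wrapper that first compares entity/variable-normalized trees with are_equal_trees, then retries over possible variable reassignments of one tree before declaring the trees unequal.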
 def are_equal_trees(self, x: Tree, y: Tree, use_terminator=False):
     if x is None or y is None:
         return False
     _x = self.normalize_entities_and_variables(x, generic=True)
     _y = self.normalize_entities_and_variables(y, generic=True)
     # print(_x)
     # print(_y)
     ret = are_equal_trees(_x,
                           _y,
                           orderless=self.orderless,
                           unktoken=self.outD[self.outD.unktoken],
                           use_terminator=use_terminator)
     if ret is False:
         return False
     _x = self.normalize_tree(x)[1]
     _y = self.normalize_tree(y)[1]
     _x = self.normalize_entities_and_variables(_x, generic=False)
     _y = self.normalize_entities_and_variables(_y, generic=False)
     # print(_x)
     # print(_y)
     ret = are_equal_trees(_x,
                           _y,
                           orderless=self.orderless,
                           unktoken=self.outD[self.outD.unktoken],
                           use_terminator=use_terminator)
     if ret is False:
         # return False
         # print(_x)
         # print(_y)
         # iterate over possible reassignments of one tree
         for reassigned_y in self.reassignments(_y):
             # reassigned_y = self.normalize_entities_and_variables(reassigned_y, generic=False)
             if are_equal_trees(reassigned_y,
                                _x,
                                orderless=self.orderless,
                                unktoken=self.outD[self.outD.unktoken],
                                use_terminator=use_terminator):
                 return True
         return False
     return ret
Example #5
File: eval.py  Project: lukovnikov/parseq
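Decodes each predicted action sequence into a tree and scores it against the corresponding gold tree.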
 def compare(_gold_trees, _predactions):
     pred_trees = [
         self.tensor2tree(predactionse) for predactionse in _predactions
     ]
     ret = [
         float(
             are_equal_trees(gold_tree,
                             pred_tree,
                             orderless=self.orderless,
                             unktoken=self.unktoken))
         for gold_tree, pred_tree in zip(_gold_trees, pred_trees)
     ]
     return ret
Example #6
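Enumerates orderings of the children of the orderless nodes and checks that every permutation still compares equal to the original; for this tree and orderless {"a", "x"} one would expect 3! * 2! = 12 orderings.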
def try_tree_permutations():
    tree = Tree("x", [
        Tree("a", [Tree("1", []), Tree("2", []), Tree("3", [])]),
        Tree("b", [Tree("1", []), Tree("2", [])]),
    ])
    print(tree)
    print("")
    perms = []
    unique_perms = set()
    for tree_perm in get_tree_permutations(tree, orderless={"a", "x"}):
        print(tree_perm)
        assert are_equal_trees(tree, tree_perm, orderless={"a", "x"})
        unique_perms.add(str(tree_perm))
        perms.append(str(tree_perm))

    print(len(unique_perms), len(perms))
Example #7
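The basic contract: parsing is whitespace-insensitive, "and" is orderless, and trees with different children are unequal.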
    def test_it(self):
        a = lisp_to_tree("( and ( wife BO ) ( spouse BO ) )")
        b = lisp_to_tree("(and (spouse BO) (wife BO))")
        c = lisp_to_tree("(and (wife BO) (wife BO))")

        print(are_equal_trees(a, a))  # should be True
        print(are_equal_trees(a, b))  # should be True
        print(are_equal_trees(a, c))  # should be False
        self.assertTrue(are_equal_trees(a, a))
        self.assertTrue(are_equal_trees(a, b))
        self.assertFalse(are_equal_trees(a, c))
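For reference, a minimal self-contained sketch of how these tests call the API; the import path parseq.grammar is our assumption about the lukovnikov/parseq layout:

# Assumed import path; adjust to wherever the project defines these helpers.
from parseq.grammar import lisp_to_tree, are_equal_trees

a = lisp_to_tree("(and (wife BO) (spouse BO))")
b = lisp_to_tree("(and (spouse BO) (wife BO))")
# children of operators listed in orderless are compared as a multiset:
print(are_equal_trees(a, b, orderless={"and", "or"}, unktoken="@UNK@"))  # True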
Example #8
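Builds ranking data: every candidate parse is compared against the gold tree with are_equal_trees to produce the binary candidate_same supervision signal.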
    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            if not isinstance(gold_tree, Tree):
                assert (gold_tree is not None)
            gold_tensor, gold_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")

            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]), None)
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert (cand_tree is not None)
                cand_tensor, cand_tokens = self.query_encoder.convert(
                    " ".join(cand["tokens"]), return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(
                    cand["alignments"]))
                candidate_align_entropies.append(
                    torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                candidate_same.append(
                    are_equal_trees(cand_tree,
                                    gold_tree,
                                    orderless={"and", "or"},
                                    unktoken="@NOUNKTOKENHERE@"))

            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0),
                                           0)
            candidate_align_tensor = torch.stack(
                q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(
                q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            state = RankState(
                inp_tensor[None, :],
                gold_tensor[None, :],
                candidate_tensor[None, :, :],
                candidate_same[None, :],
                candidate_align_tensor[None, :],
                candidate_align_entropy[None, :],
                self.sentence_encoder.vocab,
                self.query_encoder.vocab,
            )
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                             gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out
Example #9
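Like Example #3, but with optional MC dropout: the model is kept in train mode for several stochastic passes and the averaged probabilities yield confidence, NLL, and entropy estimates for the predictions.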
    def test_forward(self,
                     x: torch.Tensor,
                     gold: torch.Tensor = None
                     ):  # --> implement how decoder operates end-to-end
        preds, prednll, maxmaxnll, entropy, total, avgconf, sumnll, stepsused = self.get_prediction(
            x)

        def tensor_to_trees(x, vocab: Vocab):
            xstrs = [
                vocab.tostr(x[i]).replace("@START@", "") for i in range(len(x))
            ]
            xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
            trees = []
            for xstr in xstrs:
                # drop everything after @END@, if present
                xstr = xstr.split("@END@")[0]
                # add an opening parenthesis if it's missing
                xstr = xstr.strip()
                if len(xstr) == 0 or xstr[0] != "(":
                    xstr = "(" + xstr
                # balance parentheses
                paren_imbalance = xstr.count("(") - xstr.count(")")
                xstr = xstr + ")" * max(0, paren_imbalance)    # append missing ")"
                xstr = "(" * -min(0, paren_imbalance) + xstr   # prepend missing "("
                try:
                    tree = taglisp_to_tree(xstr)
                    # a (None, error) tuple signals a failed parse
                    if isinstance(tree, tuple) and len(tree) == 2 and tree[0] is None:
                        tree = None
                except Exception:
                    tree = None
                trees.append(tree)
            return trees

        # compute loss and metrics
        gold_trees = tensor_to_trees(gold, vocab=self.vocab)
        pred_trees = tensor_to_trees(preds, vocab=self.vocab)
        treeaccs = [
            float(
                are_equal_trees(gold_tree,
                                pred_tree,
                                orderless=ORDERLESS,
                                unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(gold_trees, pred_trees)
        ]
        ret = {
            "treeacc": torch.tensor(treeaccs).to(x.device),
            "stepsused": stepsused
        }

        if self.mcdropout > 0:
            probses = []
            preds = preds[:, 1:]
            self.train()
            for i in range(self.mcdropout):
                d, logits = self.train_forward(x, preds)
                probses.append(torch.softmax(logits, -1))
            self.eval()
            probses = sum(probses) / len(probses)
            probses = probses[:, :-1]
            probs = probses
            mask = preds > 0
            confs = torch.gather(probs, 2, preds[:, :, None])[:, :, 0]
            nlls = -torch.log(confs)

            avgconf = (confs + (1 - mask.float())).prod(-1)
            avgnll = (nlls * mask).sum(-1) / mask.float().sum(-1).clamp(1e-6)
            sumnll = (nlls * mask).sum(-1)
            maxnll, _ = (nlls + (1 - mask.float()) * -1e6).max(-1)
            entropy = (-torch.log(probs.clamp_min(1e-7)) * probs).sum(-1)
            entropy = (entropy *
                       mask).sum(-1) / mask.float().sum(-1).clamp(1e-6)
            ret["decnll"] = avgnll
            ret["sumnll"] = sumnll
            ret["maxmaxnll"] = maxnll
            ret["entropy"] = entropy
            ret["avgconf"] = avgconf
        else:
            ret["decnll"] = prednll
            ret["sumnll"] = sumnll
            ret["maxmaxnll"] = maxmaxnll
            ret["entropy"] = entropy
            ret["avgconf"] = avgconf
        return ret, pred_trees
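Example #10
A training-capable step-wise decoding loop: trees are repeatedly linearized, tagged, and executed action by action; during training, per-step cross-entropy and recall metrics are accumulated, and final tree accuracy is computed with are_equal_trees when gold is given.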
    def forward(self,
                inpseqs: torch.Tensor = None,
                gold: torch.Tensor = None,
                mode: str = None,
                maxsteps: int = None,
                **kw):
        """

        """
        maxsteps = maxsteps if maxsteps is not None else self.maxsteps
        mode = mode if mode is not None else self.mode
        device = next(self.parameters()).device

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [
                add_descendants_ancestors(goldtree) for goldtree in goldtrees
            ]
            for i in range(len(trees)):
                trees[i].align = goldtrees[i]

        i = 0

        if self.training:
            trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
        # choices = [deepcopy(trees)]
        numsteps = [0 for _ in range(len(trees))]
        treesizes = [0 for _ in range(len(trees))]
        seqlens = [0 for _ in range(len(trees))]
        allmetrics = {}
        allmetrics["loss"] = []
        allmetrics["ce"] = []
        allmetrics["elemrecall"] = []
        allmetrics["allrecall"] = []
        allmetrics["anyrecall"] = []
        allmetrics["lowestentropyrecall"] = []
        examplemasks = []
        while not all([all_terminated(tree)
                       for tree in trees]) and i < maxsteps:
            # go from tree to tensors,
            seq = []
            openmask = []
            stepgold = []
            for j, tree in enumerate(trees):
                if not all_terminated(tree):
                    numsteps[j] += 1
                treesizes[j] = tree_size(tree)

                fltoks, openmask_e, stepgold_e = extract_info(tree)
                seqlens[j] = len(fltoks)
                seq_e = self.seqenc.convert(fltoks, return_what="tensor")
                seq.append(seq_e)
                openmask.append(torch.tensor(openmask_e))
                if self.training:
                    stepgold_e_tensor = torch.zeros(
                        seq_e.size(0), self.seqenc.vocab.number_of_ids())
                    # use a fresh index to avoid shadowing the outer loop's j
                    for k, stepgold_e_i in enumerate(stepgold_e):
                        for golde in stepgold_e_i:
                            stepgold_e_tensor[k, self.seqenc.vocab[golde]] = 1
                    stepgold.append(stepgold_e_tensor)
            seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(openmask, 0, False),
                                   0).to(device)
            if self.training:
                stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0),
                                       0).to(device)

            if self.training:
                examplemask = (stepgold != 0).any(-1).any(-1)
                examplemasks.append(examplemask)

            #  feed to tagger,
            probs = self.tagger(seq, openmask=openmask, **context)

            if self.training:
                # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
                ce = self.ce(probs, stepgold, mask=openmask)
                elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                    probs, stepgold, mask=openmask)
                allmetrics["loss"].append(ce)
                allmetrics["ce"].append(ce)
                allmetrics["elemrecall"].append(elemrecall)
                allmetrics["allrecall"].append(allrecall)
                allmetrics["anyrecall"].append(anyrecall)
                allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

            #  get best predictions,
            _, best_actions = probs.max(-1)
            entropies = torch.softmax(probs, -1).clamp_min(1e-6)
            entropies = -(entropies * torch.log(entropies)).sum(-1)

            if self.training:
                newprobs = torch.softmax(probs, -1) * stepgold  # mask using gold
                newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(1e-6)  # renormalize
                uniform = stepgold
                uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
                newprobs = (newprobs * (1 - q.v(self.uniformfactor))
                            + uniform * q.v(self.uniformfactor))

                noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
                zeroprobs = torch.zeros_like(newprobs)
                zeroprobs[:, :, 0] = 1
                newprobs = newprobs * noprobsmask + (1 - noprobsmask) * zeroprobs

                sampled_gold_actions = torch.distributions.categorical.Categorical(probs=newprobs.view(-1, newprobs.size(-1)))\
                    .sample().view(newprobs.size(0), newprobs.size(1))
                taken_actions = sampled_gold_actions
            else:
                taken_actions = best_actions

            #  attach chosen actions and entropies to existing trees,
            for tree, actions_e, entropies_e in zip(trees, taken_actions,
                                                    entropies):
                actions_e = list(actions_e.detach().cpu().numpy())
                actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
                entropies_e = list(entropies_e.detach().cpu().numpy())
                self.attach_info_to_tree(tree,
                                         _chosen_action=actions_e,
                                         _entropy=entropies_e)
            # trees = [tensors_to_tree(seqe, openmask=openmaske, actions=actione, D=self.seqenc.vocab, entropies=entropies_e)
            #          for seqe, openmaske, actione, entropies_e
            #          in zip(list(seq), list(openmask), list(taken_actions), list(entropies))]

            #  and execute,
            trees_ = []
            for tree in trees:
                if tree_size(tree) < self.max_tree_size:
                    markmode = mode if not self.training else "train-" + mode
                    tree = mark_for_execution(tree,
                                              mode=markmode,
                                              entropylimit=self.entropylimit)
                    budget = [self.max_tree_size - tree_size(tree)]
                    if self.training:
                        tree, _ = convert_chosen_actions_from_str_to_node(tree)
                    tree = execute_chosen_actions(tree,
                                                  _budget=budget,
                                                  mode=mode)
                    if self.training:
                        tree = assign_gold_actions(tree, mode=mode)
                trees_.append(tree)

            trees = trees_
            i += 1
            #  then repeat until all terminated

        # after done decoding, if gold is given, run losses, else return just predictions

        ret = {}

        ret["seqlens"] = torch.tensor(seqlens).float()
        ret["treesizes"] = torch.tensor(treesizes).float()
        ret["numsteps"] = torch.tensor(numsteps).float()

        if self.training:
            assert (len(examplemasks) > 0)
            assert (len(allmetrics["loss"]) > 0)
            allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
            examplemasks = torch.stack(examplemasks, 1).float()
            _allmetrics = {}
            for k, v in allmetrics.items():
                _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(
                    1).clamp_min(1e-6)
            allmetrics = _allmetrics
            ret.update(allmetrics)

        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
            predtrees = [simplify_tree_for_eval(x) for x in trees]
            ret["treeacc"] = [
                float(
                    are_equal_trees(gold_tree,
                                    pred_tree,
                                    orderless=ORDERLESS,
                                    unktoken="@UNK@"))
                for gold_tree, pred_tree in zip(goldtrees, predtrees)
            ]
            ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

        return ret, trees
Example #11
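Orderless comparison is multiset-based: duplicated children must match in count, so these two trees are unequal.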
 def test_unordered_duplicate(self):
     a = lisp_to_tree("(and (wife BO) (wife BR) (wife BO))")
     b = lisp_to_tree("(and (wife BO) (wife BR) (wife BR))")
     print(are_equal_trees(a, b))
     self.assertFalse(are_equal_trees(a, b))
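The duplicate sensitivity above amounts to multiset equality over children. A toy re-implementation sketch (not parseq's actual algorithm) that canonicalizes orderless children by sorting their serialized forms, assuming nltk-style Tree objects as in Example #6:

from nltk import Tree

def toy_equal(x: Tree, y: Tree, orderless=frozenset({"and", "or"})) -> bool:
    """Toy multiset-based tree equality, for illustration only."""
    def canon(t):
        kids = [canon(c) for c in t]   # iterating an nltk Tree yields its children
        if t.label() in orderless:
            kids.sort()                # orderless nodes: compare children as a multiset
        return "(" + " ".join([t.label()] + kids) + ")"
    return canon(x) == canon(y)

a = Tree("and", [Tree("wife", [Tree("BO", [])]),
                 Tree("wife", [Tree("BR", [])]),
                 Tree("wife", [Tree("BO", [])])])
b = Tree("and", [Tree("wife", [Tree("BO", [])]),
                 Tree("wife", [Tree("BR", [])]),
                 Tree("wife", [Tree("BR", [])])])
print(toy_equal(a, b))  # False: BO appears twice in a but only once in b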
Example #12
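Trees with different numbers of children are never equal; the three-child nand cannot match the two-child one.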
 def test_ordered(self):
     a = lisp_to_tree("(nand (wife BO) (spouse BO))")
     b = lisp_to_tree("(nand (wife BO) (spouse BO) (child BO))")
     print(are_equal_trees(a, b))
     self.assertFalse(are_equal_trees(a, b))
Example #13
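An inference-only variant of the step-wise decoder in Example #10: no per-step supervision, just decoding followed by tree-accuracy evaluation when gold is given.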
    def forward(self,
                inpseqs: torch.Tensor = None,
                y_in: torch.Tensor = None,
                gold: torch.Tensor = None,
                mode: str = None,
                maxsteps: int = None,
                **kw):
        """

        """
        maxsteps = maxsteps if maxsteps is not None else self.maxsteps
        mode = mode if mode is not None else self.mode
        device = next(self.parameters()).device

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=y_in)

        i = 0
        while not all([all_terminated(tree)
                       for tree in trees]) and i < maxsteps:
            # go from tree to tensors,
            tensors = []
            masks = []
            for tree in trees:
                fltoks, openmask = extract_info(tree, nogold=True)
                seq = self.seqenc.convert(fltoks, return_what="tensor")
                tensors.append(seq)
                masks.append(torch.tensor(openmask))
            seq = torch.stack(q.pad_tensors(tensors, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(masks, 0, False),
                                   0).to(device)

            #  feed to tagger,
            probs = self.tagger(seq, openmask=openmask, **context)

            #  get best predictions,
            _, best = probs.max(-1)
            entropies = torch.softmax(probs, -1).clamp_min(1e-6)
            entropies = -(entropies * torch.log(entropies)).sum(-1)

            #  convert to trees,
            trees = [
                tensors_to_tree(seqe,
                                openmask=openmaske,
                                actions=beste,
                                D=self.seqenc.vocab,
                                entropies=entropies_e)
                for seqe, openmaske, beste, entropies_e in zip(
                    list(seq), list(openmask), list(best), list(entropies))
            ]

            #  and execute,
            trees_ = []
            for tree in trees:
                if tree_size(tree) < self.max_tree_size:
                    tree = mark_for_execution(tree, mode=mode)
                    budget = [self.max_tree_size - tree_size(tree)]
                    tree = execute_chosen_actions(tree,
                                                  _budget=budget,
                                                  mode=mode)
                trees_.append(tree)

            trees = trees_
            i += 1
            #  then repeat until all terminated

        # after done decoding, if gold is given, run losses, else return just predictions

        ret = {}

        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
            predtrees = [simplify_tree_for_eval(x) for x in trees]
            ret["treeacc"] = [
                float(
                    are_equal_trees(gold_tree,
                                    pred_tree,
                                    orderless=ORDERLESS,
                                    unktoken="@UNK@"))
                for gold_tree, pred_tree in zip(goldtrees, predtrees)
            ]
            ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

        return ret, trees
Example #14
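Like Example #3, but additionally reports BLEU and an LCS-based F1 over the linearized trees, plus teacher-forced NLL and accuracy from train_forward.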
    def test_forward(self,
                     x: torch.Tensor,
                     gold: torch.Tensor = None
                     ):  # --> implement how decoder operates end-to-end
        preds, prednll, maxmaxnll, stepsused = self.get_prediction(x)

        def tensor_to_trees(x, vocab: Vocab):
            xstrs = [
                vocab.tostr(x[i]).replace("@START@", "") for i in range(len(x))
            ]
            xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
            trees = []
            for xstr in xstrs:
                # drop everything after @END@, if present
                xstr = xstr.split("@END@")[0]
                # add an opening parenthesis if it's missing
                xstr = xstr.strip()
                if len(xstr) == 0 or xstr[0] != "(":
                    xstr = "(" + xstr
                # balance parentheses
                paren_imbalance = xstr.count("(") - xstr.count(")")
                xstr = xstr + ")" * max(0, paren_imbalance)    # append missing ")"
                xstr = "(" * -min(0, paren_imbalance) + xstr   # prepend missing "("
                try:
                    tree = taglisp_to_tree(xstr)
                    # a (None, error) tuple signals a failed parse
                    if isinstance(tree, tuple) and len(tree) == 2 and tree[0] is None:
                        tree = None
                except Exception:
                    tree = None
                trees.append(tree)
            return trees

        # compute loss and metrics
        gold_trees = tensor_to_trees(gold, vocab=self.vocab)
        pred_trees = tensor_to_trees(preds, vocab=self.vocab)
        treeaccs = [
            float(
                are_equal_trees(gold_tree,
                                pred_tree,
                                orderless=ORDERLESS,
                                unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(gold_trees, pred_trees)
        ]
        ret = {
            "treeacc": torch.tensor(treeaccs).to(x.device),
            "stepsused": stepsused
        }

        # compute bleu scores
        bleus = []
        lcsf1s = []
        for gold_tree, pred_tree in zip(gold_trees, pred_trees):
            if pred_tree is None or gold_tree is None:
                bleuscore = 0
                lcsf1 = 0
            else:
                gold_str = tree_to_lisp(gold_tree)
                pred_str = tree_to_lisp(pred_tree)
                bleuscore = sentence_bleu([gold_str.split(" ")],
                                          pred_str.split(" "))
                lcsn = lcs(gold_str, pred_str)
                lcsrec = lcsn / len(gold_str)
                lcsprec = lcsn / len(pred_str)
                lcsf1 = 2 * lcsrec * lcsprec / max(lcsrec + lcsprec, 1e-6)  # guard against 0/0
            bleus.append(bleuscore)
            lcsf1s.append(lcsf1)
        bleus = torch.tensor(bleus).to(x.device)
        ret["bleu"] = bleus
        ret["lcsf1"] = torch.tensor(lcsf1s).to(x.device)

        d, logits = self.train_forward(x, gold)
        nll, acc, elemacc = d["loss"], d["acc"], d["elemacc"]
        ret["nll"] = nll
        ret["acc"] = acc
        ret["elemacc"] = elemacc

        # d, logits = self.train_forward(x, preds[:, 1:])
        # decnll = d["loss"]
        # ret["decnll"] = decnll

        ret["decnll"] = prednll
        ret["maxmaxnll"] = maxmaxnll
        return ret, pred_trees