def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
               splits: Iterable[str], unktokens: Set[str] = None):
    gold_map = None
    maxlen_in, maxlen_out = 0, 0
    maxlins = 0
    numlins_counts = [0] * (self.max_lins_allowed + 1)
    if unktokens is not None:
        # index tensor that maps every rare-token id onto the @UNK@ id
        gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids(last_nonrare=False))
        for rare_token in unktokens:
            gold_map[self.query_encoder.vocab[rare_token]] = \
                self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
    for inp, out, split in zip(inputs, outputs, splits):
        inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
        gold_tree = lisp_to_tree(out)
        assert gold_tree is not None
        out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
        if split == "train":
            # add up to max_lins_allowed linearizations per training example,
            # permuting the children of orderless ("and"/"or") nodes
            gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
            numlins = 0
            for gold_tree_reordered in get_tree_permutations(gold_tree_, orderless={"and", "or"}):
                if numlins >= self.max_lins_allowed:
                    break
                out_ = tree_to_lisp(gold_tree_reordered)
                out_tensor_, out_tokens_ = self.query_encoder.convert(out_, return_what="tensor,tokens")
                if gold_map is not None:
                    # remap rare ids in the reordered tensor (the one stored in the
                    # state below), not in the original out_tensor
                    out_tensor_ = gold_map[out_tensor_]
                state = TreeDecoderState([inp], [gold_tree_reordered],
                                         inp_tensor[None, :], out_tensor_[None, :],
                                         [inp_tokens], [out_tokens_],
                                         self.sentence_encoder.vocab, self.query_encoder.vocab,
                                         token_specs=self.token_specs)
                if split not in self.data:
                    self.data[split] = []
                self.data[split].append(state)
                numlins += 1
            numlins_counts[numlins] += 1
            maxlins = max(maxlins, numlins)
        else:
            if gold_map is not None:
                out_tensor = gold_map[out_tensor]
            state = TreeDecoderState([inp], [gold_tree],
                                     inp_tensor[None, :], out_tensor[None, :],
                                     [inp_tokens], [out_tokens],
                                     self.sentence_encoder.vocab, self.query_encoder.vocab,
                                     token_specs=self.token_specs)
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
        maxlen_in = max(maxlen_in, len(inp_tokens))
        maxlen_out = max(maxlen_out, len(out_tensor))
    self.maxlen_input = maxlen_in
    self.maxlen_output = maxlen_out
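# NOTE: `get_tree_permutations` is defined elsewhere in the repo. A minimal
# sketch of what it could look like, assuming a tree node exposes a label and
# a list of children (the `Node` class here is a hypothetical stand-in for the
# repo's Tree type):

from itertools import permutations, product
from typing import Iterator, List, Set


class Node:
    """Hypothetical stand-in for the repo's tree type: a label plus children."""
    def __init__(self, label: str, children: List["Node"] = None):
        self.label = label
        self.children = children or []


def get_tree_permutations_sketch(node: Node, orderless: Set[str]) -> Iterator[Node]:
    """Yield every reordering of `node` in which only the children of nodes
    whose label is in `orderless` (e.g. "and", "or") may be permuted."""
    child_variants = [list(get_tree_permutations_sketch(c, orderless)) for c in node.children]
    for combo in product(*child_variants):
        if node.label in orderless:
            for perm in permutations(combo):
                yield Node(node.label, list(perm))
        else:
            yield Node(node.label, list(combo))

# Because this is a generator, the caller can stop early -- which is exactly what
# the `break` on max_lins_allowed in build_data above does, instead of
# materializing the full (factorial-sized) set of reorderings.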
def build_data(self, inputs: Iterable[str], outputs: Iterable[str],
               splits: Iterable[str], unktokens: Set[str] = None):
    gold_map = None
    maxlen_in, maxlen_out = 0, 0
    if unktokens is not None:
        # index tensor that maps every rare-token id onto the @UNK@ id
        gold_map = torch.arange(0, self.query_encoder.vocab.number_of_ids())
        for rare_token in unktokens:
            gold_map[self.query_encoder.vocab[rare_token]] = \
                self.query_encoder.vocab[self.query_encoder.vocab.unktoken]
    for inp, out, split in zip(inputs, outputs, splits):
        inp_tensor, inp_tokens = self.sentence_encoder.convert(inp, return_what="tensor,tokens")
        gold_tree = lisp_to_tree(out)
        assert gold_tree is not None
        out_tensor, out_tokens = self.query_encoder.convert(out, return_what="tensor,tokens")
        if gold_map is not None:
            out_tensor = gold_map[out_tensor]
        state = TreeDecoderState([inp], [gold_tree],
                                 inp_tensor[None, :], out_tensor[None, :],
                                 [inp_tokens], [out_tokens],
                                 self.sentence_encoder.vocab, self.query_encoder.vocab,
                                 token_specs=self.token_specs)
        if split == "train" and self.reorder_random is True:
            # replace the gold tensor with a randomly chosen reordering of the
            # orderless ("and") nodes of the gold tree
            gold_tree_ = tensor2tree(out_tensor, self.query_encoder.vocab)
            random_gold_tree = random.choice(get_tree_permutations(gold_tree_, orderless={"and"}))
            out_ = tree_to_lisp(random_gold_tree)
            out_tensor_, out_tokens_ = self.query_encoder.convert(out_, return_what="tensor,tokens")
            if gold_map is not None:
                out_tensor_ = gold_map[out_tensor_]
            state.gold_tensor = out_tensor_[None]
        if split not in self.data:
            self.data[split] = []
        self.data[split].append(state)
        maxlen_in = max(maxlen_in, len(inp_tokens))
        maxlen_out = max(maxlen_out, len(out_tensor))
    self.maxlen_input = maxlen_in
    self.maxlen_output = maxlen_out
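# The `gold_map` trick used by both build_data variants above remaps all rare
# token ids to the @UNK@ id with a single gather. A self-contained toy
# demonstration (the vocab size, ids and @UNK@ position are made up):

import torch

unk_id, rare_ids = 1, [4, 5]          # pretend id 1 is @UNK@, ids 4/5 are rare
gold_map = torch.arange(0, 6)         # identity mapping over a 6-id vocab
for rid in rare_ids:
    gold_map[rid] = unk_id            # redirect rare ids to @UNK@

seq = torch.tensor([0, 4, 2, 5, 3])
print(gold_map[seq])                  # tensor([0, 1, 2, 1, 3])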
def test_forward(self, x: torch.Tensor, gold: torch.Tensor = None):
    # runs the decoder end-to-end and computes evaluation metrics
    preds, prednll, maxmaxnll, stepsused = self.get_prediction(x)

    def tensor_to_trees(x, vocab: Vocab):
        xstrs = [vocab.tostr(x[i]).replace("@START@", "") for i in range(len(x))]
        xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
        trees = []
        for xstr in xstrs:
            # drop everything after @END@, if present
            xstr = xstr.split("@END@")[0]
            # add an opening parenthesis if not there
            xstr = xstr.strip()
            if len(xstr) == 0 or xstr[0] != "(":
                xstr = "(" + xstr
            # balance parentheses
            paren_imbalance = xstr.count("(") - xstr.count(")")
            xstr = xstr + ")" * max(0, paren_imbalance)    # append missing closing parentheses
            xstr = "(" * -min(0, paren_imbalance) + xstr   # prepend missing opening parentheses
            try:
                tree = taglisp_to_tree(xstr)
                if isinstance(tree, tuple) and len(tree) == 2 and tree[0] is None:
                    tree = None
            except Exception:
                tree = None
            trees.append(tree)
        return trees

    # compute loss and metrics
    gold_trees = tensor_to_trees(gold, vocab=self.vocab)
    pred_trees = tensor_to_trees(preds, vocab=self.vocab)
    treeaccs = [float(are_equal_trees(gold_tree, pred_tree,
                                      orderless=ORDERLESS, unktoken="@UNK@"))
                for gold_tree, pred_tree in zip(gold_trees, pred_trees)]
    ret = {"treeacc": torch.tensor(treeaccs).to(x.device),
           "stepsused": stepsused}

    # compute BLEU and LCS-F1 scores
    bleus, lcsf1s = [], []
    for gold_tree, pred_tree in zip(gold_trees, pred_trees):
        if pred_tree is None or gold_tree is None:
            bleuscore, lcsf1 = 0., 0.
        else:
            gold_str = tree_to_lisp(gold_tree)
            pred_str = tree_to_lisp(pred_tree)
            bleuscore = sentence_bleu([gold_str.split(" ")], pred_str.split(" "))
            lcsn = lcs(gold_str, pred_str)
            lcsrec = lcsn / len(gold_str)
            lcsprec = lcsn / len(pred_str)
            # guard against division by zero when the strings share nothing
            lcsf1 = (2 * lcsrec * lcsprec / (lcsrec + lcsprec)
                     if lcsrec + lcsprec > 0 else 0.)
        bleus.append(bleuscore)
        lcsf1s.append(lcsf1)
    ret["bleu"] = torch.tensor(bleus).to(x.device)
    ret["lcsf1"] = torch.tensor(lcsf1s).to(x.device)

    d, logits = self.train_forward(x, gold)
    nll, acc, elemacc = d["loss"], d["acc"], d["elemacc"]
    ret["nll"] = nll
    ret["acc"] = acc
    ret["elemacc"] = elemacc
    # d, logits = self.train_forward(x, preds[:, 1:])
    # decnll = d["loss"]
    # ret["decnll"] = decnll
    ret["decnll"] = prednll
    ret["maxmaxnll"] = maxmaxnll
    return ret, pred_trees
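# The `lcs` helper used for the lcsf1 metric above is not shown in this
# section. Assuming it returns the length of the character-level longest
# common subsequence (which the normalizations by len(gold_str) and
# len(pred_str) suggest), a standard rolling-array DP could look like:

def lcs_sketch(a: str, b: str) -> int:
    """Length of the longest common subsequence of a and b, O(len(a)*len(b))."""
    prev = [0] * (len(b) + 1)
    for ca in a:
        cur = [0]
        for j, cb in enumerate(b, 1):
            cur.append(prev[j - 1] + 1 if ca == cb else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]

assert lcs_sketch("( foo bar )", "( foo baz )") == 10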
def tree_to_lisptokens(self, x: Tree):
    xstr = tree_to_lisp(x)
    xstr = xstr.replace("(", " ( ").replace(")", " ) ")
    # collapse runs of whitespace and strip the edges so that splitting
    # does not produce empty tokens
    xstr = re.sub(r"\s+", " ", xstr).strip()
    return xstr.split(" ")
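# A quick string-level trace of what tree_to_lisptokens does (the input lisp
# string here is an arbitrary example):

import re

xstr = "(and (foo bar) (baz))"
xstr = xstr.replace("(", " ( ").replace(")", " ) ")
xstr = re.sub(r"\s+", " ", xstr).strip()
print(xstr.split(" "))
# ['(', 'and', '(', 'foo', 'bar', ')', '(', 'baz', ')', ')']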