def get_init_state(self, inpseqs=None, y_in=None) -> Tuple[ATree, Dict]:
    """
    Encode the input sequences and start one fresh decoding tree per example.

    :param inpseqs: batch of input token ids; entries equal to 0 are treated as
        padding in the returned mask (presumably a (batsize, seqlen) id tensor
        -- TODO confirm).
    :param y_in:    must be None; starting decoding from a non-vanilla state
        is not supported yet.
    :return: pair ``(trees, state)``. ``trees`` is a list with one
        ``ATree("@START@", [])`` per batch element (a list, despite the
        ``Tuple[ATree, Dict]`` annotation). ``state`` is a dict holding the
        (optionally adapted) encoder output under "enc" and the non-padding
        mask under "encmask".
    """
    # Decoding from a non-vanilla (partially built) state is not supported yet.
    assert (y_in is None)

    pad_mask = (inpseqs != 0)

    hidden = self.bert_model(inpseqs)[0]
    if self.adapter is not None:
        hidden = self.adapter(hidden)

    start_trees = []
    for _ in range(inpseqs.size(0)):
        start_trees.append(ATree("@START@", []))

    state = {"enc": hidden, "encmask": pad_mask}
    return start_trees, state
def convert_chosen_actions_from_str_to_node(x: ATree, c=0):
    """
    Replace string-valued chosen actions with the matching entry from ``gold_actions``.

    The tree rooted at ``x`` is walked recursively and modified in place. For
    every open node whose ``_chosen_action`` is set (it must be a ``str``), the
    first entry of ``x.gold_actions`` that matches replaces it:

    * an ``ATree`` gold action matches when its ``label()`` equals the string;
    * a ``str`` gold action matches when it equals the string.

    If nothing matches, the string is left as it is.

    :param x: root of the (sub)tree to convert.
    :param c: running count of open nodes with a chosen action seen so far.
    :return: ``(x, c)`` with the updated count. At most one such node is
        allowed in the whole tree (asserted).
    """
    if x.is_open and x._chosen_action is not None:
        chosen = x._chosen_action
        assert (isinstance(chosen, str))
        for gold in x.gold_actions:
            if isinstance(gold, ATree):
                matched = gold.label() == chosen
            elif isinstance(gold, str):
                matched = gold == chosen
            else:
                matched = False
            if matched:
                x._chosen_action = gold
                break
        c += 1

    new_children = []
    for kid in x:
        kid, c = convert_chosen_actions_from_str_to_node(kid, c)
        new_children.append(kid)
    x[:] = new_children

    # Only a single node in the tree may carry a chosen action.
    assert (c <= 1)
    return x, c
def assign_gold_actions(x: ATree, mode="default"):
    """
    Recursively annotate every node of ``x`` with ``gold_actions`` and a sampled ``_chosen_action``.

    Children are processed before their parent. For each node:

    * closed nodes get no gold actions;
    * "(" and ")" nodes are marked closed and get no gold actions;
    * "@SLOT@" nodes get the children of the siblings' aligned parent that lie
      between the ``align`` targets of the slot's left and right siblings, or
      ["@CLOSE@"] if there are none;
    * other open leaves get every node in ``x.align._descendants``;
    * other open non-leaves get the entries of their first non-slot child's
      reversed ``align._ancestors`` list that come before ``x.align``, or
      ["@CLOSE@"] if there are none.

    In "ltr" mode each list above is truncated to its first element. Finally
    ``_chosen_action`` is set to ``random.choice`` over the gold actions, or to
    None when there are none.

    :param x:       tree to annotate; modified in place.
    :param mode:    "default" (all) or "ltr" (only first one)
    :return:        ``x``
    :raises Exception: if a "@SLOT@" node is the only child of its parent.
    """
    """ assigns actions that can be taken at every node of the given tree """
    # Post-order: annotate all children before deciding this node's actions.
    for xe in x:
        assign_gold_actions(xe, mode=mode)
    if not x.is_open:
        # Closed nodes cannot act.
        x.gold_actions = []
    else:
        if x.label() == ")" or x.label() == "(":
            # Parenthesis nodes are forced closed and get no actions.
            x.is_open = False
            x.gold_actions = []
        elif x.label() == "@SLOT@":
            if len(x.parent) == 1:
                # A slot may not be the only child of its parent.
                raise Exception()
            # get this slot's siblings and the positions of their aligned nodes
            # (assumes child_number_of returns the index of a node within its parent)
            x.gold_actions = []
            xpos = child_number_of(x)
            if xpos == 0:
                leftsibling = None
                leftsibling_nr = None
            else:
                leftsibling = x.parent[xpos - 1]
                leftsibling_nr = child_number_of(leftsibling.align)
            if xpos == len(x.parent) - 1:
                rightsibling = None
                rightsibling_nr = None
            else:
                rightsibling = x.parent[xpos + 1]
                rightsibling_nr = child_number_of(rightsibling.align)
            if leftsibling is None and rightsibling is None:
                # slot is only child, can use any descendant
                # NOTE(review): unreachable as written -- both siblings are None only
                # when len(x.parent) == 1, which already raised above.
                x.gold_actions = x.parent.align.descendants
                if mode == "ltr" and len(x.gold_actions) > 0:
                    x.gold_actions = [x.gold_actions[0]]
                assert (False)  # should not happen if deletion actions are not used
            else:
                # Take the children of the siblings' aligned parent that sit strictly
                # after the left sibling's aligned node and before the right sibling's
                # (open-ended on a side without a sibling).
                p = leftsibling.align.parent if leftsibling is not None else rightsibling.align.parent
                slicefrom = leftsibling_nr + 1 if leftsibling_nr is not None else None
                slicer = slice(slicefrom, rightsibling_nr)
                x.gold_actions = p[slicer]
                if mode == "ltr" and len(x.gold_actions) > 0:
                    x.gold_actions = [x.gold_actions[0]]
            if len(x.gold_actions) == 0:
                # Nothing left to insert at this slot: the only gold action is closing it.
                x.gold_actions = ["@CLOSE@"]
        else:  # not a sibling slot ("@SLOT@"), not a "(" or ")"
            x.gold_actions = []
            if len(x) == 0:
                # Leaf: any node in x.align._descendants is a valid action.
                x.gold_actions = list(x.align._descendants)
                if mode == "ltr" and len(x.gold_actions) > 0:
                    x.gold_actions = [x.gold_actions[0]]
            else:
                # Non-leaf: all non-slot children must share the same aligned ancestor
                # chain (asserted below); the gold actions are the ancestors in that chain
                # that are encountered before reaching x.align.
                # NOTE(review): raises IndexError if every child is an "@SLOT@".
                realchildren = [xe for xe in x if xe.label() != "@SLOT@"]
                # assumes _ancestors is ordered root-first, so the reversed list walks
                # upward from the child's parent -- TODO confirm
                childancestors = realchildren[0].align._ancestors[::-1]
                for child in realchildren:
                    assert (childancestors == child.align._ancestors[::-1])
                for ancestor in childancestors:
                    if ancestor is x.align:
                        break
                    else:
                        x.gold_actions.append(ancestor)
                if mode == "ltr" and len(x.gold_actions) > 0:
                    x.gold_actions = [x.gold_actions[0]]
            if len(x.gold_actions) == 0 and x.is_open:
                # Nothing left to insert at this node: the only gold action is closing it.
                x.gold_actions = ["@CLOSE@"]
    if len(x.gold_actions) > 0:
        # Sample one gold action uniformly (uses the global `random` module state).
        # x._chosen_action = x.gold_actions[0]
        x._chosen_action = random.choice(x.gold_actions)
    else:
        x._chosen_action = None
    return x
def load_ds(domain="restaurants", nl_mode="bert-base-uncased", trainonvalid=False, noreorder=False):
    """
    Load an Overnight domain and encode its splits as (NL ids, FL ids) pairs.

    Each logical-form tree is wrapped under an "@START@" root. Unless
    ``noreorder`` is set, it is then passed through ``reorder_tree`` with the
    ``orderless`` operator set. A target-side ``SequenceEncoder`` vocabulary is
    built over the query tokens of all three splits; only training tokens are
    counted as seen. Every example is then mapped to
    ``(bert_token_ids, fl_token_ids_tensor)``.

    :param domain:       Overnight domain passed to ``OvernightDatasetLoader().load``.
    :param nl_mode:      name of the pretrained BERT tokenizer used for the questions.
    :param trainonvalid: forwarded to ``OvernightDatasetLoader().load``.
    :param noreorder:    if True, skip ``reorder_tree`` on the query trees.
    :return: ``(train_ds, valid_ds, test_ds, nl_tokenizer, seqenc, orderless)``
    """
    orderless = {"op:and", "SW:concat"}  # only use in eval!!

    ds = OvernightDatasetLoader().load(domain=domain, trainonvalid=trainonvalid)
    ds = ds.map(lambda x: (x[0], ATree("@START@", [x[1]]), x[2]))
    if not noreorder:
        ds = ds.map(lambda x: (x[0], reorder_tree(x[1], orderless=orderless), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    # Special tokens, registered with an infinite "seen" count (np.inf: the
    # np.infty alias was removed in NumPy 2.0):
    #   @START@                    root label of every query tree
    #   @CLOSE@                    action closing an open position; not seen at input
    #   @OPEN@                     action opening a closed position; not seen at input
    #   @REMOVE@, @REMOVESUBTREE@  deletion actions; not seen at input
    #   @SLOT@                     seen at input, can't be produced
    for special in ("@START@", "@CLOSE@", "@OPEN@", "@REMOVE@", "@REMOVESUBTREE@", "@SLOT@"):
        vocab.add_token(special, seen=np.inf)

    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)

    tds = ds[lambda x: x[2] == "train"]
    vds = ds[lambda x: x[2] == "valid"]
    xds = ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(
        vocab=vocab,
        tokenizer=lambda x: extract_info(x, onlytokens=True),
        add_start_token=False,
        add_end_token=False)
    # Query vocabulary spans every split, but only training tokens count as seen.
    for split, seen in ((tds, True), (vds, False), (xds, False)):
        for example in split.examples:
            seqenc.inc_build_vocab(example[1], seen=seen)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        # (question, query tree, split) -> (BERT ids of question, query token-id tensor)
        fltoks = extract_info(x[1], onlytokens=True)
        seq = seqenc.convert(fltoks, return_what="tensor")
        nl_ids = nl_tokenizer.encode(x[0], return_tensors="pt")[0]
        return nl_ids, seq

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless
def execute_chosen_actions(x: ATree, _budget=None, mode="full"):
    """
    Apply every open node's ``_chosen_action`` to the tree rooted at ``x``, in place.

    * Nodes without a chosen action, or closed nodes, only recurse into their children.
    * "(" / ")" nodes: no-op.
    * "@SLOT@" with "@CLOSE@": the slot is removed from its parent. If the
      parent is left holding exactly "(" and ")", those are removed too.
    * "@SLOT@" with any other action: the slot is relabelled with the action
      (and aligned to it when it is a ``Tree``) and stays open. A new "@SLOT@"
      is inserted to its right, and also to its left unless ``mode == "ltr"``.
    * Any other node: children are processed first. Then "@CLOSE@" closes the
      node; any other action inserts a new open node between ``x`` and its
      children, flanked by slots (no left slot in "ltr" mode, where ``x`` is
      also closed).

    :param x:       root of the tree to edit.
    :param _budget: one-element list with the remaining number of nodes that
        may be inserted. It is shared with (and updated by) the recursive calls:
        insertions decrement it and closing a slot increments it. When it
        reaches 0, insertions stop. ``None`` (default) means an unlimited
        budget, freshly created for each top-level call rather than a mutable
        default list shared between calls.
    :param mode:    "ltr" for left-to-right insertion (right slots only);
        anything else inserts slots on both sides.
    :return: ``x``
    """
    if _budget is None:
        # np.inf, not np.infty: the alias was removed in NumPy 2.0.
        _budget = [np.inf]

    if x._chosen_action is None or not x.is_open:
        # Snapshot the children: recursive calls may restructure x's child list.
        for xe in list(x):
            execute_chosen_actions(xe, _budget=_budget, mode=mode)
        return x

    if x.label() == "(":
        # insert a parent before current parent (not implemented)
        pass
    elif x.label() == ")":
        pass
    elif x.label() == "@SLOT@":
        if x._chosen_action == "@CLOSE@":
            # Remove the slot; this frees one unit of budget.
            del x.parent[child_number_of(x)]
            _budget[0] += 1
            # if the parentheses became empty, remove them as well
            if len(x.parent) == 2 and x.parent[0].label() == "(" and x.parent[1].label() == ")":
                x.parent[:] = []
        else:
            if _budget[0] <= 0:
                return x
            action = x._chosen_action
            if isinstance(action, Tree):
                x.set_label(action.label())
                x.align = action
            else:
                x.set_label(action)
            x.is_open = True
            if mode != "ltr":
                leftslot = ATree("@SLOT@", [], is_open=True)
                leftslot.parent = x.parent
                x.parent.insert(child_number_of(x), leftslot)
                _budget[0] -= 1
            rightslot = ATree("@SLOT@", [], is_open=True)
            rightslot.parent = x.parent
            # x's index is re-read so the insertion lands right after it even if a
            # left slot was just inserted before it.
            x.parent.insert(child_number_of(x) + 1, rightslot)
            _budget[0] -= 1
    else:
        for xe in list(x):
            execute_chosen_actions(xe, _budget=_budget, mode=mode)
        if _budget[0] <= 0:
            return x
        if x._chosen_action == "@CLOSE@":
            x.is_open = False  # this node can't generate children anymore
        else:
            # X(A, B, C) -> X(@SLOT@, Y(A, B, C), @SLOT@)   (no left slot in "ltr" mode)
            action = x._chosen_action
            if isinstance(action, Tree):
                newnode = ATree(action.label(), [])
                newnode.align = action
            else:
                newnode = ATree(action, [])
            newnode.is_open = True
            if mode == "ltr":
                x.is_open = False
            newnode.parent = x
            # The new node adopts all of x's current children.
            newnode[:] = x[:]
            for xe in newnode:
                xe.parent = newnode
            rightslot = ATree("@SLOT@", [], is_open=True)
            rightslot.parent = x
            if mode != "ltr":
                leftslot = ATree("@SLOT@", [], is_open=True)
                leftslot.parent = x
                x[:] = [leftslot, newnode, rightslot]
                _budget[0] -= 3
            else:
                x[:] = [newnode, rightslot]
                _budget[0] -= 2
    return x