# Assumed imports for the snippets in this section; project-level helpers
# (State, RankState, Tree, lisp_to_tree, are_equal_trees, tensors_to_tree, ...)
# come from the surrounding codebase.
from typing import Iterable, List

import torch
import qelos as q
from torch.utils.data.dataloader import default_collate


def collate_fn(data: Iterable):
    goldmaxlen = 0
    inpmaxlen = 0
    # work on copies so padding does not mutate the original states
    data = [state.make_copy(detach=True, deep=True) for state in data]
    for state in data:
        goldmaxlen = max(goldmaxlen, state.gold_tensor.size(1))
        inpmaxlen = max(inpmaxlen, state.inp_tensor.size(1))
        goldmaxlen = max(goldmaxlen, state.candtensors.size(-1))
    # pad every per-example tensor to the batch maximum along its length dim
    inp_tensors = q.pad_tensors([state.inp_tensor for state in data], 1, 0)
    gold_tensors = q.pad_tensors([state.gold_tensor for state in data], 1, 0)
    candtensors = q.pad_tensors([state.candtensors for state in data], 2, 0)
    alignments = q.pad_tensors([state.alignments for state in data], 2, 0)
    alignment_entropies = q.pad_tensors(
        [state.alignment_entropies for state in data], 2, 0)
    # write the padded tensors back and merge everything into one batched state
    for i, state in enumerate(data):
        state.inp_tensor = inp_tensors[i]
        state.gold_tensor = gold_tensors[i]
        state.candtensors = candtensors[i]
        state.alignments = alignments[i]
        state.alignment_entropies = alignment_entropies[i]
    ret = data[0].merge(data)
    return ret
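# The collate functions in this section all lean on q.pad_tensors from qelos.
# Below is a minimal single-dimension sketch of its assumed behavior (the real
# qelos implementation also accepts a tuple of dimensions, as used further
# down): pad every tensor in a list along `dim` to the batch maximum, filling
# with `value`. `pad_tensors_sketch` is a hypothetical name for illustration.
def pad_tensors_sketch(tensors, dim=0, value=0):
    maxlen = max(t.size(dim) for t in tensors)
    out = []
    for t in tensors:
        if t.size(dim) < maxlen:
            padshape = list(t.shape)
            padshape[dim] = maxlen - t.size(dim)
            pad = torch.full(padshape, value, dtype=t.dtype, device=t.device)
            t = torch.cat([t, pad], dim)
        out.append(t)
    return out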
def collate_fn(x, pad_value_nl=0, pad_value_fl=0):
    y = list(zip(*x))
    assert len(y) == 3
    # first two elements padded with the NL pad value, third with the FL value
    # (presumably natural-language inputs vs. formal-language outputs)
    y[0] = torch.stack(q.pad_tensors(y[0], 0, pad_value_nl), 0)
    y[1] = torch.stack(q.pad_tensors(y[1], 0, pad_value_nl), 0)
    y[2] = torch.stack(q.pad_tensors(y[2], 0, pad_value_fl), 0)
    return tuple(y)
def collate_fn(x, pad_value=0):
    y = list(zip(*x))
    assert len(y) == 4
    y[0] = torch.stack(q.pad_tensors(y[0], 0, pad_value), 0)
    y[1] = torch.stack(q.pad_tensors(y[1], 0, pad_value), 0)
    y[2] = torch.stack(q.pad_tensors(y[2], 0, False), 0)  # boolean mask: pad with False
    y[3] = torch.stack(q.pad_tensors(y[3], 0, pad_value), 0)
    return y
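# Hedged usage sketch for the 4-tuple collate_fn above: a DataLoader over a
# toy dataset of (ids, ids, mask, ids) tuples. The toy tensors are
# illustrative, not taken from the original code.
from functools import partial
from torch.utils.data import DataLoader

_toy = [
    (torch.tensor([1, 2, 3]), torch.tensor([7, 8]),
     torch.tensor([True, True, False]), torch.tensor([4])),
    (torch.tensor([5, 6]), torch.tensor([9]),
     torch.tensor([True]), torch.tensor([4, 5])),
]
_dl = DataLoader(_toy, batch_size=2,
                 collate_fn=partial(collate_fn, pad_value=0))
_batch = next(iter(_dl))
# _batch[0]: (2, 3) LongTensor padded with 0
# _batch[2]: (2, 3) BoolTensor padded with False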
def get_arrays_to_save(x: List[State]):
    inp_tensors = torch.cat(q.pad_tensors([xe.inp_tensor for xe in x], 1), 0)
    followed_actions = torch.cat(
        q.pad_tensors([xe.followed_actions for xe in x], 1), 0)
    gold_tensors = torch.cat(q.pad_tensors([xe.gold_tensor for xe in x], 1), 0)
    # attentions are padded along both the step and the source dimensions
    attentions = torch.cat(
        q.pad_tensors([xe.stored_attentions for xe in x], (1, 2)), 0)
    return {
        "inp_tensor": inp_tensors,
        "gold_tensor": gold_tensors,
        "followed_actions": followed_actions,
        "attentions": attentions,
    }
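# Usage sketch (path and variable names illustrative): the dict returned by
# get_arrays_to_save is a plain mapping of tensors, so it can be written with
# torch.save and restored later with torch.load.
# arrays = get_arrays_to_save(finished_states)
# torch.save(arrays, "stored_run.pt")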
def pad_and_default_collate(x, pad_value=0):
    y = list(zip(*x))
    for i, yi in enumerate(y):
        # pad variable-length 1-D integer tensors to a common length,
        # then let default_collate do the batching
        if isinstance(yi[0], torch.LongTensor) and yi[0].dim() == 1:
            y[i] = q.pad_tensors(yi, 0, pad_value)
    x = list(zip(*y))
    ret = default_collate(x)
    return ret
def autocollate(x, pad_value=0):
    y = list(zip(*x))
    for i, yi in enumerate(y):
        # pad variable-length 1-D integer tensors to a common length
        if isinstance(yi[0], torch.LongTensor) and yi[0].dim() == 1:
            y[i] = q.pad_tensors(yi, 0, pad_value)
    for i, yi in enumerate(y):
        # stack any tensor group along a new batch dimension
        if isinstance(yi[0], torch.Tensor):
            yi = [yij[None] for yij in yi]
            y[i] = torch.cat(yi, 0)
    return y
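# Hedged example of autocollate on a toy batch that mixes a variable-length
# LongTensor with a scalar label tensor (toy values are illustrative):
_xs = [(torch.tensor([1, 2, 3]), torch.tensor(0)),
       (torch.tensor([4, 5]), torch.tensor(1))]
_seqs, _labels = autocollate(_xs)
# _seqs   -> LongTensor of shape (2, 3), second row right-padded with 0
# _labels -> LongTensor of shape (2,)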
def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
    maxlen_in, maxlen_out = 0, 0
    for example, split in zip(examples, splits):
        inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
        inp_tensor, inp_tokens = self.sentence_encoder.convert(
            inp, return_what="tensor,tokens")
        gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
        if not isinstance(gold_tree, Tree):
            assert gold_tree is not None
        gold_tensor, gold_tokens = self.query_encoder.convert(
            out, return_what="tensor,tokens")

        candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
        candidate_align_entropies = []
        candidate_trees = []
        candidate_same = []
        for cand in example["candidates"]:
            cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]), None)
            if cand_tree is None:
                cand_tree = Tree("@UNK@", [])
            assert cand_tree is not None
            cand_tensor, cand_tokens = self.query_encoder.convert(
                " ".join(cand["tokens"]), return_what="tensor,tokens")
            candidate_tensors.append(cand_tensor)
            candidate_tokens.append(cand_tokens)
            candidate_align_tensors.append(torch.tensor(cand["alignments"]))
            candidate_align_entropies.append(
                torch.tensor(cand["align_entropies"]))
            candidate_trees.append(cand_tree)
            # a candidate counts as correct if its tree equals the gold tree
            candidate_same.append(
                are_equal_trees(cand_tree, gold_tree,
                                orderless={"and", "or"},
                                unktoken="@NOUNKTOKENHERE@"))

        candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0), 0)
        candidate_align_tensor = torch.stack(
            q.pad_tensors(candidate_align_tensors, 0), 0)
        candidate_align_entropy = torch.stack(
            q.pad_tensors(candidate_align_entropies, 0), 0)
        candidate_same = torch.tensor(candidate_same)

        state = RankState(
            inp_tensor[None, :],
            gold_tensor[None, :],
            candidate_tensor[None, :, :],
            candidate_same[None, :],
            candidate_align_tensor[None, :],
            candidate_align_entropy[None, :],
            self.sentence_encoder.vocab,
            self.query_encoder.vocab,
        )
        if split not in self.data:
            self.data[split] = []
        self.data[split].append(state)
        maxlen_in = max(maxlen_in, len(inp_tokens))
        maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                         gold_tensor.size(-1))
    self.maxlen_input = maxlen_in
    self.maxlen_output = maxlen_out
def forward(self,
            inpseqs: torch.Tensor = None,
            gold: torch.Tensor = None,
            mode: str = None,
            maxsteps: int = None,
            **kw):
    maxsteps = maxsteps if maxsteps is not None else self.maxsteps
    mode = mode if mode is not None else self.mode
    device = next(self.parameters()).device

    trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
    if gold is not None:
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab)
                     for seqe in list(gold)]
        goldtrees = [add_descendants_ancestors(goldtree)
                     for goldtree in goldtrees]
        for i in range(len(trees)):
            trees[i].align = goldtrees[i]

    i = 0
    if self.training:
        trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
    # choices = [deepcopy(trees)]
    numsteps = [0 for _ in range(len(trees))]
    treesizes = [0 for _ in range(len(trees))]
    seqlens = [0 for _ in range(len(trees))]
    allmetrics = {"loss": [], "ce": [], "elemrecall": [], "allrecall": [],
                  "anyrecall": [], "lowestentropyrecall": []}
    examplemasks = []

    while not all([all_terminated(tree) for tree in trees]) and i < maxsteps:
        # go from trees to tensors
        seq = []
        openmask = []
        stepgold = []
        for j, tree in enumerate(trees):
            if not all_terminated(tree):
                numsteps[j] += 1
            treesizes[j] = tree_size(tree)
            fltoks, openmask_e, stepgold_e = extract_info(tree)
            seqlens[j] = len(fltoks)
            seq_e = self.seqenc.convert(fltoks, return_what="tensor")
            seq.append(seq_e)
            openmask.append(torch.tensor(openmask_e))
            if self.training:
                # multi-hot tensor of good actions per sequence position
                stepgold_e_tensor = torch.zeros(
                    seq_e.size(0), self.seqenc.vocab.number_of_ids())
                for k, stepgold_e_i in enumerate(stepgold_e):
                    for golde in stepgold_e_i:
                        stepgold_e_tensor[k, self.seqenc.vocab[golde]] = 1
                stepgold.append(stepgold_e_tensor)
        seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
        openmask = torch.stack(q.pad_tensors(openmask, 0, False), 0).to(device)
        if self.training:
            stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0), 0).to(device)
            examplemask = (stepgold != 0).any(-1).any(-1)
            examplemasks.append(examplemask)

        # feed to tagger
        probs = self.tagger(seq, openmask=openmask, **context)

        if self.training:
            # stepgold is (batsize, seqlen, vocsize) with ones for good actions
            ce = self.ce(probs, stepgold, mask=openmask)
            elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                probs, stepgold, mask=openmask)
            allmetrics["loss"].append(ce)
            allmetrics["ce"].append(ce)
            allmetrics["elemrecall"].append(elemrecall)
            allmetrics["allrecall"].append(allrecall)
            allmetrics["anyrecall"].append(anyrecall)
            allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

        # get best predictions and per-position entropies
        _, best_actions = probs.max(-1)
        entropies = torch.softmax(probs, -1).clamp_min(1e-6)
        entropies = -(entropies * torch.log(entropies)).sum(-1)

        if self.training:
            newprobs = torch.softmax(probs, -1) * stepgold  # mask using gold
            newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(1e-6)  # renormalize
            uniform = stepgold
            uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
            newprobs = newprobs * (1 - q.v(self.uniformfactor)) \
                + uniform * q.v(self.uniformfactor)
            # positions without any gold action get a point mass on id 0
            noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
            zeroprobs = torch.zeros_like(newprobs)
            zeroprobs[:, :, 0] = 1
            newprobs = newprobs * noprobsmask + (1 - noprobsmask) * zeroprobs
            sampled_gold_actions = torch.distributions.categorical.Categorical(
                probs=newprobs.view(-1, newprobs.size(-1))) \
                .sample().view(newprobs.size(0), newprobs.size(1))
            taken_actions = sampled_gold_actions
        else:
            taken_actions = best_actions

        # attach chosen actions and entropies to the existing trees
        for tree, actions_e, entropies_e in zip(trees, taken_actions, entropies):
            actions_e = list(actions_e.detach().cpu().numpy())
            actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
            entropies_e = list(entropies_e.detach().cpu().numpy())
            self.attach_info_to_tree(tree,
                                     _chosen_action=actions_e,
                                     _entropy=entropies_e)
        # trees = [tensors_to_tree(seqe, openmask=openmaske, actions=actione,
        #                          D=self.seqenc.vocab, entropies=entropies_e)
        #          for seqe, openmaske, actione, entropies_e
        #          in zip(list(seq), list(openmask), list(taken_actions), list(entropies))]

        # and execute
        trees_ = []
        for tree in trees:
            if tree_size(tree) < self.max_tree_size:
                markmode = mode if not self.training else "train-" + mode
                tree = mark_for_execution(tree, mode=markmode,
                                          entropylimit=self.entropylimit)
                budget = [self.max_tree_size - tree_size(tree)]
                if self.training:
                    tree, _ = convert_chosen_actions_from_str_to_node(tree)
                tree = execute_chosen_actions(tree, _budget=budget, mode=mode)
                if self.training:
                    tree = assign_gold_actions(tree, mode=mode)
            trees_.append(tree)
        trees = trees_
        i += 1
        # then repeat until all terminated

    # after decoding: if gold is given, compute losses/metrics, else return just predictions
    ret = {}
    ret["seqlens"] = torch.tensor(seqlens).float()
    ret["treesizes"] = torch.tensor(treesizes).float()
    ret["numsteps"] = torch.tensor(numsteps).float()

    if self.training:
        assert len(examplemasks) > 0
        assert len(allmetrics["loss"]) > 0
        allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
        examplemasks = torch.stack(examplemasks, 1).float()
        _allmetrics = {}
        for k, v in allmetrics.items():
            # average each metric over steps, masked by which examples were active
            _allmetrics[k] = (v * examplemasks).sum(1) \
                / examplemasks.sum(1).clamp_min(1e-6)
        allmetrics = _allmetrics
        ret.update(allmetrics)

    if gold is not None:
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab)
                     for seqe in list(gold)]
        goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
        predtrees = [simplify_tree_for_eval(x) for x in trees]
        ret["treeacc"] = [
            float(are_equal_trees(gold_tree, pred_tree,
                                  orderless=ORDERLESS, unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(goldtrees, predtrees)
        ]
        ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

    return ret, trees
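# A standalone sketch of the gold-action sampling used in the training branch
# above, under the same shape assumptions: probs and stepgold are
# (batsize, seqlen, vocsize), with stepgold multi-hot over allowed gold
# actions. Model probabilities are masked to gold actions, renormalized, mixed
# with a uniform distribution over gold actions, and sampled; positions with
# no gold action fall back to a point mass on id 0. `sample_gold_actions` is a
# hypothetical helper for illustration, not part of the original code.
def sample_gold_actions(probs, stepgold, uniformfactor=0.1):
    p = torch.softmax(probs, -1) * stepgold                         # keep gold actions
    p = p / p.sum(-1, keepdim=True).clamp_min(1e-6)                 # renormalize
    uniform = stepgold / stepgold.sum(-1, keepdim=True).clamp_min(1e-6)
    p = p * (1 - uniformfactor) + uniform * uniformfactor           # mix with uniform
    hasgold = (p != 0).any(-1, keepdim=True).float()
    fallback = torch.zeros_like(p)
    fallback[:, :, 0] = 1                                           # point mass on id 0
    p = p * hasgold + (1 - hasgold) * fallback
    dist = torch.distributions.Categorical(probs=p.view(-1, p.size(-1)))
    return dist.sample().view(p.size(0), p.size(1))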
def forward(self,
            inpseqs: torch.Tensor = None,
            y_in: torch.Tensor = None,
            gold: torch.Tensor = None,
            mode: str = None,
            maxsteps: int = None,
            **kw):
    maxsteps = maxsteps if maxsteps is not None else self.maxsteps
    mode = mode if mode is not None else self.mode
    device = next(self.parameters()).device

    trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=y_in)

    i = 0
    while not all([all_terminated(tree) for tree in trees]) and i < maxsteps:
        # go from trees to tensors
        tensors = []
        masks = []
        for tree in trees:
            fltoks, openmask = extract_info(tree, nogold=True)
            seq = self.seqenc.convert(fltoks, return_what="tensor")
            tensors.append(seq)
            masks.append(torch.tensor(openmask))
        seq = torch.stack(q.pad_tensors(tensors, 0), 0).to(device)
        openmask = torch.stack(q.pad_tensors(masks, 0, False), 0).to(device)

        # feed to tagger
        probs = self.tagger(seq, openmask=openmask, **context)

        # get best predictions and per-position entropies
        _, best = probs.max(-1)
        entropies = torch.softmax(probs, -1).clamp_min(1e-6)
        entropies = -(entropies * torch.log(entropies)).sum(-1)

        # convert back to trees
        trees = [tensors_to_tree(seqe, openmask=openmaske, actions=beste,
                                 D=self.seqenc.vocab, entropies=entropies_e)
                 for seqe, openmaske, beste, entropies_e
                 in zip(list(seq), list(openmask), list(best), list(entropies))]

        # and execute
        trees_ = []
        for tree in trees:
            if tree_size(tree) < self.max_tree_size:
                tree = mark_for_execution(tree, mode=mode)
                budget = [self.max_tree_size - tree_size(tree)]
                tree = execute_chosen_actions(tree, _budget=budget, mode=mode)
            trees_.append(tree)
        trees = trees_
        i += 1
        # then repeat until all terminated

    # after decoding: if gold is given, compute tree accuracy, else return just predictions
    ret = {}
    if gold is not None:
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab)
                     for seqe in list(gold)]
        goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
        predtrees = [simplify_tree_for_eval(x) for x in trees]
        ret["treeacc"] = [
            float(are_equal_trees(gold_tree, pred_tree,
                                  orderless=ORDERLESS, unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(goldtrees, predtrees)
        ]
        ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)
    return ret, trees
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode input
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iterate over layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask

    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt is True:
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate

    # attend over the input context extended with previous decoder states
    if "prevstates" not in mstate:
        _ctx = ctx
        _ctx_mask = ctx_mask
        mstate.prevstates = enc[:, None, :]
    else:
        _ctx = torch.cat([ctx, mstate.prevstates], 1)
        _ctx_mask = torch.cat([
            ctx_mask,
            torch.ones(mstate.prevstates.size(0),
                       mstate.prevstates.size(1),
                       dtype=ctx_mask.dtype,
                       device=ctx_mask.device)
        ], 1)
        mstate.prevstates = torch.cat([mstate.prevstates, enc[:, None, :]], 1)

    alphas, summ, scores = self.att(enc, _ctx, _ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)

    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        # append this step's attention weights to the stored history
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1),
                                              device=alphas.device)
        atts = q.pad_tensors(
            [x.stored_attentions, alphas.detach()[:, None, :]], 2, 0)
        x.stored_attentions = torch.cat(atts, 1)

    return outs[0], x
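# A minimal sketch of the stored-attention bookkeeping at the end of forward
# above: each step's attention weights (batsize, ctxlen) are appended along a
# step axis, with older entries right-padded when the attended context grows.
# Shapes and values here are illustrative only.
import torch.nn.functional as F

_history = torch.zeros(2, 0, 5)             # (batsize, steps so far, ctxlen)
_alphas = torch.rand(2, 7)                  # current step, over a longer context
_ctxlen = max(_history.size(2), _alphas.size(1))
_history = F.pad(_history, (0, _ctxlen - _history.size(2)))
_step = F.pad(_alphas[:, None, :], (0, _ctxlen - _alphas.size(1)))
_history = torch.cat([_history, _step], 1)  # now (2, 1, 7)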