def forward(self, probs, predactions, gold, x: State = None) -> Dict:
    # get the per-example penalty vector from the state
    penalty_vec = self.getter(x)
    assert penalty_vec.dim() == 1 and penalty_vec.size(0) == probs.size(0)
    # reduce the penalty vector according to the configured reduction mode
    if self.reduction in ("mean", "default"):
        penalty = penalty_vec.mean()
    elif self.reduction == "sum":
        penalty = penalty_vec.sum()
    elif self.reduction in ("none", None):
        penalty = penalty_vec
    else:
        raise Exception(f"unknown reduction mode: {self.reduction}")
    # scale by the (possibly scheduled) penalty weight and the contribution factor
    ret = penalty * q.v(self.weight)
    ret = ret * self.contrib
    return {"loss": ret, self._name: ret}
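# A minimal, standalone sketch (not part of the codebase; values are hypothetical) of the
# three reduction modes handled in the penalty loss above: "mean"/"default" averages the
# per-example penalties, "sum" adds them up, and "none" keeps the unreduced vector.
def _penalty_reduction_sketch():
    import torch
    penalty_vec = torch.tensor([0.2, 0.4, 0.6])  # one penalty value per batch element
    return {
        "mean": penalty_vec.mean(),   # scalar: average penalty
        "sum": penalty_vec.sum(),     # scalar: total penalty
        "none": penalty_vec,          # (batsize,) unreduced penalties
    }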
def forward(self, probs, gold):
    """
    :param probs: (batsize, ..., vocsize) logits
    :param gold:  (batsize, ...) int ids of the correct class
    :return: label-smoothed KL-divergence loss
    """
    # which entries count as masked out depends on whether we got logits/logprobs or probs
    _prob_mask_crit = -np.inf if self.mode in ("logits", "logprobs") else 0
    lsv = q.v(self.smoothing)  # current value of the label-smoothing hyperparameter
    assert 0 <= lsv <= 1
    # (batsize, ..., vocsize) 1. where probs are above the criterion,
    # reverse-engineering a -inf mask that was applied outside
    prob_mask = (probs > _prob_mask_crit).float()
    prob_mask_weights = lsv / prob_mask.sum(-1, keepdim=True)
    # build the smoothed target distribution: spread lsv mass over the unmasked ids
    # and put the remaining (1 - lsv) mass on the gold id
    _gold = torch.ones_like(probs) * prob_mask_weights * prob_mask
    _gold.scatter_(-1, gold.unsqueeze(-1), (1 - lsv) + prob_mask_weights)  # (batsize, ..., vocsize) probs
    assert (_gold.sum(-1) - torch.ones_like(gold).float()).norm().cpu().item() < 1e-5
    # get log-probabilities in whichever way matches the input mode
    logprobs = self.sm(probs) if self.mode == "logits" \
        else (probs if self.mode == "logprobs" else torch.log(probs))
    kl_divs = self.kl(logprobs, _gold.detach())
    # kl_divs = inf2zero(kl_divs)
    kl_div = kl_divs.sum(-1)  # (batsize, ...) KL divergence per element
    if self.weight is not None:  # per-class weights
        kl_div = kl_div * self.weight[gold]
    mask = DiscreteLoss.get_ignore_mask(gold, self.ignore_indices).float()
    kl_div = kl_div * mask
    ret = kl_div.sum()
    if self.reduction in ("elementwise_mean", "mean"):
        total = mask.sum()
        ret = ret / total
    elif self.reduction == "none":
        ret = kl_div
    return ret
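# A minimal, self-contained sketch (not part of the codebase; shapes and the smoothing value
# are hypothetical, and the prob-mask handling for masked-out ids is omitted) of the
# label-smoothing target construction used above: spread `smoothing` mass uniformly over the
# vocabulary, put the remaining (1 - smoothing) mass on the gold id, then take the KL
# divergence against the model's log-softmax.
def _label_smoothing_kl_sketch():
    import torch
    batsize, vocsize, smoothing = 2, 5, 0.1            # hypothetical sizes / hyperparameter
    logits = torch.randn(batsize, vocsize)             # (batsize, vocsize) raw scores
    gold = torch.tensor([1, 3])                        # (batsize,) gold class ids
    target = torch.full((batsize, vocsize), smoothing / vocsize)
    target.scatter_(-1, gold[:, None], (1 - smoothing) + smoothing / vocsize)
    assert torch.allclose(target.sum(-1), torch.ones(batsize))   # rows sum to one
    logprobs = torch.log_softmax(logits, -1)
    kl = torch.nn.functional.kl_div(logprobs, target, reduction="none").sum(-1)
    return kl.mean()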
def forward(self, inpseqs: torch.Tensor = None, gold: torch.Tensor = None,
            mode: str = None, maxsteps: int = None, **kw):
    """ Iteratively decode trees from the input sequences; during training,
        per-step losses and recall metrics are computed against the gold trees. """
    maxsteps = maxsteps if maxsteps is not None else self.maxsteps
    mode = mode if mode is not None else self.mode
    device = next(self.parameters()).device

    # initialize trees and encoder context from the input sequences
    trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)

    if gold is not None:
        # decode the gold tensors into trees and align them with the predicted trees
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in list(gold)]
        goldtrees = [add_descendants_ancestors(goldtree) for goldtree in goldtrees]
        for i in range(len(trees)):
            trees[i].align = goldtrees[i]

    i = 0

    if self.training:
        trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
    # choices = [deepcopy(trees)]
    numsteps = [0 for _ in range(len(trees))]
    treesizes = [0 for _ in range(len(trees))]
    seqlens = [0 for _ in range(len(trees))]
    allmetrics = {}
    allmetrics["loss"] = []
    allmetrics["ce"] = []
    allmetrics["elemrecall"] = []
    allmetrics["allrecall"] = []
    allmetrics["anyrecall"] = []
    allmetrics["lowestentropyrecall"] = []
    examplemasks = []

    while not all([all_terminated(tree) for tree in trees]) and i < maxsteps:
        # go from trees to tensors
        seq = []
        openmask = []
        stepgold = []
        for j, tree in enumerate(trees):
            if not all_terminated(tree):
                numsteps[j] += 1
            treesizes[j] = tree_size(tree)
            fltoks, openmask_e, stepgold_e = extract_info(tree)
            seqlens[j] = len(fltoks)
            seq_e = self.seqenc.convert(fltoks, return_what="tensor")
            seq.append(seq_e)
            openmask.append(torch.tensor(openmask_e))
            if self.training:
                # build a multi-hot gold tensor over the vocabulary for every position
                stepgold_e_tensor = torch.zeros(seq_e.size(0), self.seqenc.vocab.number_of_ids())
                for k, stepgold_e_i in enumerate(stepgold_e):  # k indexes positions in this example's sequence
                    for golde in stepgold_e_i:
                        stepgold_e_tensor[k, self.seqenc.vocab[golde]] = 1
                stepgold.append(stepgold_e_tensor)
        seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
        openmask = torch.stack(q.pad_tensors(openmask, 0, False), 0).to(device)
        if self.training:
            stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0), 0).to(device)

        if self.training:
            examplemask = (stepgold != 0).any(-1).any(-1)
            examplemasks.append(examplemask)

        # feed to the tagger
        probs = self.tagger(seq, openmask=openmask, **context)

        if self.training:
            # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
            ce = self.ce(probs, stepgold, mask=openmask)
            elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                probs, stepgold, mask=openmask)
            allmetrics["loss"].append(ce)
            allmetrics["ce"].append(ce)
            allmetrics["elemrecall"].append(elemrecall)
            allmetrics["allrecall"].append(allrecall)
            allmetrics["anyrecall"].append(anyrecall)
            allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

        # get best predictions and per-position entropies
        _, best_actions = probs.max(-1)
        entropies = torch.softmax(probs, -1).clamp_min(1e-6)
        entropies = -(entropies * torch.log(entropies)).sum(-1)

        if self.training:
            # sample gold actions proportionally to the model's probabilities,
            # mixed with a uniform distribution over the gold actions
            newprobs = torch.softmax(probs, -1) * stepgold  # mask using gold
            newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(1e-6)  # renormalize
            uniform = stepgold
            uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
            newprobs = newprobs * (1 - q.v(self.uniformfactor)) + uniform * q.v(self.uniformfactor)
            # positions without any gold action get a dummy distribution on id 0
            noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
            zeroprobs = torch.zeros_like(newprobs)
            zeroprobs[:, :, 0] = 1
            newprobs = newprobs * noprobsmask + (1 - noprobsmask) * zeroprobs
            sampled_gold_actions = torch.distributions.categorical.Categorical(
                probs=newprobs.view(-1, newprobs.size(-1))) \
                .sample().view(newprobs.size(0), newprobs.size(1))
            taken_actions = sampled_gold_actions
        else:
            taken_actions = best_actions
        # attach chosen actions and entropies to the existing trees
        for tree, actions_e, entropies_e in zip(trees, taken_actions, entropies):
            actions_e = list(actions_e.detach().cpu().numpy())
            actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
            entropies_e = list(entropies_e.detach().cpu().numpy())
            self.attach_info_to_tree(tree, _chosen_action=actions_e, _entropy=entropies_e)
        # trees = [tensors_to_tree(seqe, openmask=openmaske, actions=actione, D=self.seqenc.vocab, entropies=entropies_e)
        #          for seqe, openmaske, actione, entropies_e
        #          in zip(list(seq), list(openmask), list(taken_actions), list(entropies))]

        # and execute the chosen actions on every tree that still fits the size budget
        trees_ = []
        for tree in trees:
            if tree_size(tree) < self.max_tree_size:
                markmode = mode if not self.training else "train-" + mode
                tree = mark_for_execution(tree, mode=markmode, entropylimit=self.entropylimit)
                budget = [self.max_tree_size - tree_size(tree)]
                if self.training:
                    tree, _ = convert_chosen_actions_from_str_to_node(tree)
                tree = execute_chosen_actions(tree, _budget=budget, mode=mode)
                if self.training:
                    tree = assign_gold_actions(tree, mode=mode)
            trees_.append(tree)
        trees = trees_

        i += 1
        # then repeat until all trees are terminated or maxsteps is reached

    # after decoding is done: if gold is given, compute losses, else return just predictions
    ret = {}
    ret["seqlens"] = torch.tensor(seqlens).float()
    ret["treesizes"] = torch.tensor(treesizes).float()
    ret["numsteps"] = torch.tensor(numsteps).float()

    if self.training:
        assert len(examplemasks) > 0
        assert len(allmetrics["loss"]) > 0
        # stack per-step metrics along a new step dimension and average over valid steps
        allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
        examplemasks = torch.stack(examplemasks, 1).float()
        _allmetrics = {}
        for k, v in allmetrics.items():
            _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(1).clamp_min(1e-6)
        allmetrics = _allmetrics
        ret.update(allmetrics)

    if gold is not None:
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in list(gold)]
        goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
        predtrees = [simplify_tree_for_eval(x) for x in trees]
        ret["treeacc"] = [
            float(are_equal_trees(gold_tree, pred_tree, orderless=ORDERLESS, unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(goldtrees, predtrees)
        ]
        ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

    return ret, trees
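# A toy, standalone sketch (assumed shapes and names, not the project's API) of the
# training-time action sampling above: model probabilities are masked to the gold-allowed
# actions, renormalized, mixed with a uniform distribution over those gold actions, and a
# gold action is then sampled per position from the resulting distribution.
def _gold_action_sampling_sketch():
    import torch
    batsize, seqlen, vocsize, uniformfactor = 2, 3, 6, 0.25   # hypothetical values
    logits = torch.randn(batsize, seqlen, vocsize)
    stepgold = (torch.rand(batsize, seqlen, vocsize) > 0.5).float()  # 1 for allowed gold actions
    newprobs = torch.softmax(logits, -1) * stepgold                  # mask using gold
    newprobs = newprobs / newprobs.sum(-1, keepdim=True).clamp_min(1e-6)   # renormalize
    uniform = stepgold / stepgold.sum(-1, keepdim=True).clamp_min(1e-6)
    newprobs = newprobs * (1 - uniformfactor) + uniform * uniformfactor
    # fall back to a dummy distribution on id 0 at positions with no gold action at all
    has_gold = (newprobs != 0).any(-1, keepdim=True).float()
    fallback = torch.zeros_like(newprobs)
    fallback[..., 0] = 1
    newprobs = newprobs * has_gold + (1 - has_gold) * fallback
    sampled = torch.distributions.Categorical(probs=newprobs.view(-1, vocsize)) \
        .sample().view(batsize, seqlen)
    return sampled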
def forward(self, x: State):
    if "mstate" not in x:
        x.mstate = State()
        x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long,
                                             device=x.inp_tensor.device)
    mstate = x.mstate
    init_states = []
    if "ctx" not in mstate:
        # encode the input once and cache it in the model state
        inptensor = x.inp_tensor
        mask = inptensor != 0
        inpembs = self.inp_emb(inptensor)
        # inpembs = self.dropout(inpembs)
        inpenc, final_encs = self.inp_enc(inpembs, mask)
        for i, final_enc in enumerate(final_encs):  # iterate over encoder layers
            _fenc = self.enc_to_dec[i](final_enc[0])
            init_states.append(_fenc)
        mstate.ctx = inpenc
        mstate.ctx_mask = mask

        if self.training and q.v(self.beta) < 1:
            # sample one of the alternative gold orders
            golds = x._gold_tensors
            goldsmask = (golds != 0).any(-1).float()
            numgolds = goldsmask.sum(-1)
            gold_select_prob = torch.ones_like(goldsmask) * goldsmask / numgolds[:, None]
            selector = gold_select_prob.multinomial(1)[:, 0]
            gold = golds.gather(1, selector[:, None, None].repeat(1, 1, golds.size(2)))[:, 0]
            # interpolate with the original gold: keep the original with probability beta
            original_gold = x.gold_tensor
            beta_selector = (torch.rand_like(numgolds) <= q.v(self.beta)).long()
            gold_ = original_gold * beta_selector[:, None] + gold * (1 - beta_selector[:, None])
            x.gold_tensor = gold_

    ctx = mstate.ctx
    ctx_mask = mstate.ctx_mask

    emb = self.out_emb(x.prev_actions)

    if "rnnstate" not in mstate:
        init_rnn_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
        # uncomment next line to initialize decoder state with last state of encoder
        # init_rnn_state[f"{len(init_rnn_state)-1}"]["c"] = final_enc
        if len(init_states) == init_rnn_state.h.size(1):
            init_rnn_state.h = torch.stack(init_states, 1).contiguous()
        mstate.rnnstate = init_rnn_state

    if "prev_summ" not in mstate:
        # mstate.prev_summ = torch.zeros_like(ctx[:, 0])
        mstate.prev_summ = final_encs[-1][0]
    _emb = emb
    if self.feedatt == True:
        # feed the previous attention summary into the decoder RNN input
        _emb = torch.cat([_emb, mstate.prev_summ], 1)
    enc, new_rnnstate = self.out_rnn(_emb, mstate.rnnstate)
    mstate.rnnstate = new_rnnstate

    alphas, summ, scores = self.att(enc, ctx, ctx_mask)
    mstate.prev_summ = summ
    enc = torch.cat([enc, summ], -1)

    if self.training:
        out_mask = None
    else:
        out_mask = x.get_out_mask(device=enc.device)

    if self.nocopy is True:
        outs = self.out_lin(enc, out_mask)
    else:
        outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
    outs = (outs,) if not q.issequence(outs) else outs
    # _, preds = outs.max(-1)

    if self.store_attn:
        # accumulate attention weights over decoding steps for later inspection
        if "stored_attentions" not in x:
            x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
        x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

    mstate.decoding_step = mstate.decoding_step + 1
    return outs[0], x
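# A small, self-contained sketch (hypothetical tensors, not the project's State API) of the
# gold-order interpolation above: with probability `beta` the original gold linearization is
# kept, otherwise one uniformly sampled alternative gold linearization is substituted.
def _gold_order_mixing_sketch():
    import torch
    batsize, numorders, seqlen, beta = 2, 4, 5, 0.8              # hypothetical values
    golds = torch.randint(1, 10, (batsize, numorders, seqlen))   # alternative gold orders
    original_gold = golds[:, 0]                                  # pretend this is the canonical order
    goldsmask = (golds != 0).any(-1).float()                     # which orders are non-empty
    select_prob = goldsmask / goldsmask.sum(-1, keepdim=True)
    selector = select_prob.multinomial(1)[:, 0]                  # one sampled order per example
    sampled_gold = golds.gather(1, selector[:, None, None].repeat(1, 1, seqlen))[:, 0]
    keep_original = (torch.rand(batsize) <= beta).long()
    mixed = original_gold * keep_original[:, None] + sampled_gold * (1 - keep_original[:, None])
    return mixed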
def on_start_train_epoch():
    # decay the penalty weight by a constant factor at the start of every training epoch
    penweight.v /= 1.2
    print(q.v(penweight))
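# A tiny illustration (assumed names, not the project's API) of the mutable hyperparameter
# holder the hook above relies on: `penweight` is expected to expose its current value as
# `.v`, so dividing `.v` by 1.2 once per epoch yields an exponential decay schedule.
class _HyperparamSketch:
    def __init__(self, v):
        self.v = v

penweight_example = _HyperparamSketch(1.0)
for _epoch in range(3):
    penweight_example.v /= 1.2   # same update as on_start_train_epoch above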