예제 #1
0
파일: eval.py 프로젝트: lukovnikov/parseq
 def forward(self, probs, predactions, gold, x: State = None) -> Dict:
     """
     Turn a per-example penalty stored in the state into a weighted loss term.

     ``predactions`` and ``gold`` are accepted but not used; ``probs`` is only
     consulted for its batch size.

     :param probs:       tensor whose first dimension must equal the length of the penalty vector
     :param predactions: unused
     :param gold:        unused
     :param x:           state handed to ``self.getter`` to obtain the penalty vector
     :return:            dict with the weighted penalty under both "loss" and ``self._name``
     """
     penalties = self.getter(x)
     # exactly one penalty value per example in the batch
     assert (penalties.dim() == 1
             and penalties.size(0) == probs.size(0))

     mode = self.reduction
     if mode == "sum":
         reduced = penalties.sum()
     elif mode in ("mean", "default"):
         reduced = penalties.mean()
     elif mode in ("none", None):
         # keep the unreduced per-example vector
         reduced = penalties
     else:
         raise Exception(f"unknown reduction mode: {self.reduction}")

     # scale by the (possibly scheduled) weight and by the contribution factor
     weighted = reduced * q.v(self.weight) * self.contrib
     return {"loss": weighted, self._name: weighted}
예제 #2
0
    def forward(self, probs, gold):
        """
        Label-smoothed KL-divergence loss of ``probs`` against the gold class ids.

        The smoothing mass ``q.v(self.smoothing)`` is spread uniformly over the classes that are
        not masked out in ``probs``; the gold class additionally receives ``1 - smoothing``.

        :param probs:   (batsize, ..., vocsize) scores, interpreted according to ``self.mode``:
                        "logits" are passed through ``self.sm`` (presumably a log-softmax -- confirm),
                        "logprobs" are used as-is, anything else is passed through ``torch.log``
        :param gold:    (batsize, ..., ) int ids of correct class
        :return:        with ``self.reduction == "none"``: (batsize, ...) masked KL div per element;
                        with "mean"/"elementwise_mean": masked sum divided by the number of
                        non-ignored elements; otherwise: the masked sum
        """
        # Threshold above which a class counts as "not masked out".
        # NOTE: float("-inf") replaces the `np.infty` alias, which was removed in NumPy 2.0.
        _prob_mask_crit = float("-inf") if self.mode in ("logits", "logprobs") else 0
        lsv = q.v(self.smoothing)  # get value of label smoothing hyperparam
        assert (lsv >= 0 and lsv <= 1)

        # (batsize, ..., vocsize): 1 where a class is above the threshold,
        # reverse engineering a -infty (or zero-probability) mask applied outside
        prob_mask = (probs > _prob_mask_crit).float()
        # smoothing mass per allowed class
        prob_mask_weights = lsv / prob_mask.sum(-1, keepdim=True)
        _gold = torch.ones_like(probs) * prob_mask_weights * prob_mask
        # gold class gets the remaining mass on top of its smoothing share
        _gold.scatter_(-1, gold.unsqueeze(-1),
                       (1 - lsv) + prob_mask_weights)  # (batsize, ..., vocsize) probs
        # sanity check: every target distribution sums to one
        assert ((_gold.sum(-1) -
                 torch.ones_like(gold).float()).norm().cpu().item() < 1e-5)

        if self.mode == "logits":
            logprobs = self.sm(probs)
        elif self.mode == "logprobs":
            logprobs = probs
        else:
            logprobs = torch.log(probs)
        kl_divs = self.kl(logprobs, _gold.detach())
        kl_div = kl_divs.sum(-1)  # (batsize, ...) kl div per element

        if self.weight is not None:
            # per-class weights, looked up by gold id
            kl_div = kl_div * self.weight[gold]

        # zero out the elements whose gold id is one of the ignored indices
        mask = DiscreteLoss.get_ignore_mask(gold, self.ignore_indices).float()
        kl_div = kl_div * mask
        ret = kl_div.sum()
        if self.reduction in ["elementwise_mean", "mean"]:
            total = mask.sum()
            ret = ret / total
        elif self.reduction == "none":
            ret = kl_div
        return ret
    def forward(self,
                inpseqs: torch.Tensor = None,
                gold: torch.Tensor = None,
                mode: str = None,
                maxsteps: int = None,
                **kw):
        """
        Decode a batch of trees step by step.

        At every step the current trees are flattened to tensors, scored by ``self.tagger``,
        the chosen actions are attached to the trees and then executed. Decoding stops when all
        trees are terminated or after ``maxsteps`` steps.

        In training mode the taken actions are sampled from the tagger's distribution restricted
        to the gold actions of that step (interpolated with a uniform distribution over those gold
        actions using ``self.uniformfactor``) and per-step loss/recall metrics are collected.
        Outside training the highest-scoring action is taken at every position.

        :param inpseqs:     input sequences, handed to ``self.tagger.get_init_state``
        :param gold:        optional gold trees as a tensor with one entry per example along the first
                            dimension; used to align the decoded trees and to compute "treeacc"
        :param mode:        mode string passed on to the tree helpers; defaults to ``self.mode``
        :param maxsteps:    maximum number of decoding steps; defaults to ``self.maxsteps``
        :param kw:          ignored
        :return:            tuple ``(ret, trees)`` where ``trees`` are the decoded trees and ``ret`` is
                            a dict with float tensors "seqlens", "treesizes", "numsteps", plus (in
                            training) the step-averaged metrics "loss", "ce", "elemrecall", "allrecall",
                            "anyrecall", "lowestentropyrecall", plus (if ``gold`` is given) "treeacc"
        :raises AssertionError: in training mode when not a single decoding step was run
        """
        maxsteps = maxsteps if maxsteps is not None else self.maxsteps
        mode = mode if mode is not None else self.mode
        device = next(self.parameters()).device

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [
                add_descendants_ancestors(goldtree) for goldtree in goldtrees
            ]
            # attach the gold tree to its decoding tree
            for tree_idx in range(len(trees)):
                trees[tree_idx].align = goldtrees[tree_idx]

        if self.training:
            trees = [assign_gold_actions(tree, mode=mode) for tree in trees]

        # per-example bookkeeping, overwritten/updated at every step
        numsteps = [0] * len(trees)
        treesizes = [0] * len(trees)
        seqlens = [0] * len(trees)
        # per-step metric values, only filled in training mode
        allmetrics = {
            name: []
            for name in ("loss", "ce", "elemrecall", "allrecall",
                         "anyrecall", "lowestentropyrecall")
        }
        examplemasks = []
        # does not change during decoding, so computed once
        markmode = mode if not self.training else "train-" + mode

        step = 0
        while not all(all_terminated(tree)
                      for tree in trees) and step < maxsteps:
            # go from tree to tensors,
            seq = []
            openmask = []
            stepgold = []
            for j, tree in enumerate(trees):
                if not all_terminated(tree):
                    numsteps[j] += 1
                treesizes[j] = tree_size(tree)

                fltoks, openmask_e, stepgold_e = extract_info(tree)
                seqlens[j] = len(fltoks)
                seq_e = self.seqenc.convert(fltoks, return_what="tensor")
                seq.append(seq_e)
                openmask.append(torch.tensor(openmask_e))
                if self.training:
                    # multi-hot (seqlen, vocsize) matrix of allowed gold actions per position
                    stepgold_e_tensor = torch.zeros(
                        seq_e.size(0), self.seqenc.vocab.number_of_ids())
                    # NOTE: separate index name so the per-tree index `j` is not clobbered
                    for pos, stepgold_e_i in enumerate(stepgold_e):
                        for golde in stepgold_e_i:
                            stepgold_e_tensor[pos, self.seqenc.vocab[golde]] = 1
                    stepgold.append(stepgold_e_tensor)
            seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(openmask, 0, False),
                                   0).to(device)
            if self.training:
                stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0),
                                       0).to(device)
                # (batsize,) True for examples that have at least one gold action at this step
                examplemasks.append((stepgold != 0).any(-1).any(-1))

            #  feed to tagger,
            probs = self.tagger(seq, openmask=openmask, **context)

            if self.training:
                # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
                ce = self.ce(probs, stepgold, mask=openmask)
                elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                    probs, stepgold, mask=openmask)
                allmetrics["loss"].append(ce)
                allmetrics["ce"].append(ce)
                allmetrics["elemrecall"].append(elemrecall)
                allmetrics["allrecall"].append(allrecall)
                allmetrics["anyrecall"].append(anyrecall)
                allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

            #  get best predictions,
            _, best_actions = probs.max(-1)
            # softmax is needed for both the entropies and the gold sampling: compute it once
            softprobs = torch.softmax(probs, -1)
            clamped = softprobs.clamp_min(1e-6)  # avoids log(0)
            entropies = -(clamped * torch.log(clamped)).sum(-1)

            if self.training:
                newprobs = softprobs * stepgold  # mask using gold
                newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(
                    1e-6)  # renormalize
                # uniform distribution over the gold actions of every position
                uniform = stepgold / stepgold.sum(-1)[:, :, None].clamp_min(1e-6)
                uniformfactor = q.v(self.uniformfactor)
                newprobs = newprobs * (1 - uniformfactor) + uniform * uniformfactor

                # positions without any gold action would have an all-zero distribution,
                # which can not be sampled from: put all mass on id 0 there
                hasprobs = (newprobs != 0).any(-1, keepdim=True).float()
                zeroprobs = torch.zeros_like(newprobs)
                zeroprobs[:, :, 0] = 1
                newprobs = newprobs * hasprobs + (1 - hasprobs) * zeroprobs

                sampler = torch.distributions.categorical.Categorical(
                    probs=newprobs.view(-1, newprobs.size(-1)))
                taken_actions = sampler.sample().view(newprobs.size(0),
                                                      newprobs.size(1))
            else:
                taken_actions = best_actions

            #  attach chosen actions and entropies to existing trees,
            for tree, actions_e, entropies_e in zip(trees, taken_actions,
                                                    entropies):
                actions_e = list(actions_e.detach().cpu().numpy())
                # map ids through the vocabulary (presumably id -> token; confirm)
                actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
                entropies_e = list(entropies_e.detach().cpu().numpy())
                self.attach_info_to_tree(tree,
                                         _chosen_action=actions_e,
                                         _entropy=entropies_e)

            #  and execute,
            trees_ = []
            for tree in trees:
                # trees that reached the size limit are carried over unchanged
                if tree_size(tree) < self.max_tree_size:
                    tree = mark_for_execution(tree,
                                              mode=markmode,
                                              entropylimit=self.entropylimit)
                    # wrapped in a list, presumably so the callee can update it in place -- confirm
                    budget = [self.max_tree_size - tree_size(tree)]
                    if self.training:
                        tree, _ = convert_chosen_actions_from_str_to_node(tree)
                    tree = execute_chosen_actions(tree,
                                                  _budget=budget,
                                                  mode=mode)
                    if self.training:
                        tree = assign_gold_actions(tree, mode=mode)
                trees_.append(tree)

            trees = trees_
            step += 1
            #  then repeat until all terminated

        # after done decoding, if gold is given, run losses, else return just predictions

        ret = {}

        ret["seqlens"] = torch.tensor(seqlens).float()
        ret["treesizes"] = torch.tensor(treesizes).float()
        ret["numsteps"] = torch.tensor(numsteps).float()

        if self.training:
            assert (len(examplemasks) > 0)
            assert (len(allmetrics["loss"]) > 0)
            # (batsize, numsteps) per metric
            allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
            examplemasks = torch.stack(examplemasks, 1).float()
            # average every metric over the steps in which the example had gold actions
            _allmetrics = {}
            for k, v in allmetrics.items():
                _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(
                    1).clamp_min(1e-6)
            allmetrics = _allmetrics
            ret.update(allmetrics)

        if gold is not None:
            # rebuild the gold trees from the tensor for comparison with the predictions
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab)
                for seqe in list(gold)
            ]
            goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
            predtrees = [simplify_tree_for_eval(x) for x in trees]
            ret["treeacc"] = [
                float(
                    are_equal_trees(gold_tree,
                                    pred_tree,
                                    orderless=ORDERLESS,
                                    unktoken="@UNK@"))
                for gold_tree, pred_tree in zip(goldtrees, predtrees)
            ]
            ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

        return ret, trees
예제 #4
0
    def forward(self, x:State):
        """
        Run one decoder step on the given state.

        On the first call (no "ctx" in ``x.mstate``) the input is encoded and cached in the state;
        in training with ``q.v(self.beta) < 1`` one of the alternative gold sequences in
        ``x._gold_tensors`` is additionally sampled per example and mixed into ``x.gold_tensor``.
        Every call then embeds ``x.prev_actions``, advances the decoder RNN, attends over the
        cached encoding and computes the output scores.

        NOTE: "prev_summ" is initialised from the encoder's final states, which only exist when
        the encoder ran in this same call.

        :param x:   decoding state; updated in place (``mstate``, optionally ``gold_tensor``
                    and ``stored_attentions``)
        :return:    tuple of (first output of ``self.out_lin``, the updated state ``x``)
        """
        if "mstate" not in x:
            x.mstate = State()
            x.mstate.decoding_step = torch.zeros(x.inp_tensor.size(0), dtype=torch.long, device=x.inp_tensor.device)
        mstate = x.mstate

        init_states = []
        if "ctx" not in mstate:
            # encode input; id 0 is treated as padding
            src_ids = x.inp_tensor
            src_mask = src_ids != 0
            src_embs = self.inp_emb(src_ids)
            src_enc, final_encs = self.inp_enc(src_embs, src_mask)
            # project the final encoder state of every layer for the decoder
            init_states.extend(self.enc_to_dec[layer](layer_final[0])
                               for layer, layer_final in enumerate(final_encs))
            mstate.ctx = src_enc
            mstate.ctx_mask = src_mask

            if self.training and q.v(self.beta) < 1:    # sample one of the orders
                candidates = x._gold_tensors
                # 1 for candidates that contain at least one non-zero id
                cand_mask = (candidates != 0).any(-1).float()
                num_cands = cand_mask.sum(-1)
                pick_probs = torch.ones_like(cand_mask) * cand_mask / num_cands[:, None]
                picked = pick_probs.multinomial(1)[:, 0]
                sampled_gold = candidates.gather(
                    1, picked[:, None, None].repeat(1, 1, candidates.size(2)))[:, 0]
                # interpolate with original gold: 1 keeps the original, 0 takes the sampled one
                keep_original = (torch.rand_like(num_cands) <= q.v(self.beta)).long()
                x.gold_tensor = (x.gold_tensor * keep_original[:, None]
                                 + sampled_gold * (1 - keep_original[:, None]))

        ctx, ctx_mask = mstate.ctx, mstate.ctx_mask

        emb = self.out_emb(x.prev_actions)

        if "rnnstate" not in mstate:
            fresh_state = self.out_rnn.get_init_state(emb.size(0), emb.device)
            # only take over the encoder states when there is one per decoder layer
            if len(init_states) == fresh_state.h.size(1):
                fresh_state.h = torch.stack(init_states, 1).contiguous()
            mstate.rnnstate = fresh_state

        if "prev_summ" not in mstate:
            mstate.prev_summ = final_encs[-1][0]

        # optionally feed the previous attention summary together with the embedding
        rnn_inp = torch.cat([emb, mstate.prev_summ], 1) if self.feedatt == True else emb
        enc, next_rnnstate = self.out_rnn(rnn_inp, mstate.rnnstate)
        mstate.rnnstate = next_rnnstate

        alphas, summ, scores = self.att(enc, ctx, ctx_mask)
        mstate.prev_summ = summ
        enc = torch.cat([enc, summ], -1)

        # no output mask is used during training
        out_mask = None if self.training else x.get_out_mask(device=enc.device)

        if self.nocopy is True:
            outs = self.out_lin(enc, out_mask)
        else:
            outs = self.out_lin(enc, x.inp_tensor, scores, out_mask=out_mask)
        if not q.issequence(outs):
            outs = (outs,)

        if self.store_attn:
            if "stored_attentions" not in x:
                x.stored_attentions = torch.zeros(alphas.size(0), 0, alphas.size(1), device=alphas.device)
            # append this step's attention weights along dimension 1
            x.stored_attentions = torch.cat([x.stored_attentions, alphas.detach()[:, None, :]], 1)

        mstate.decoding_step = mstate.decoding_step + 1

        return outs[0], x
예제 #5
0
 def on_start_train_epoch():
     """Divide ``penweight.v`` by 1.2 in place and print the value read back through ``q.v``."""
     penweight.v /= 1.2
     current = q.v(penweight)
     print(current)