Example #1
 def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
     conf = self.conf
     # scoring
     arc_score, lab_score = self._score(enc_expr,
                                        mask_expr)  # [bs, m, h, *]
     # loss
     bsize, max_len = BK.get_shape(mask_expr)
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in insts])
     # todo(note): here use the original idx of label, no shift!
     gold_labels_arr, _ = self.predict_padder.pad(
         [z.labels.idxes for z in insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
     # collect the losses
     arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
     arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
     # logsoftmax and losses
     arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1),
                                      -1)  # [bs, m, h]
     lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
     arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                  gold_heads_expr]  # [bs, Len]
     lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                  gold_heads_expr,
                                  gold_labels_expr]  # [bs, Len]
     # head selection (no root)
     arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
     lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
     final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
     final_loss_count = mask_expr[:, 1:].sum()
     return [[final_loss, final_loss_count]]
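
For orientation, the gold-selection pattern in Example #1 can be reproduced in plain PyTorch. This is only a minimal sketch, assuming BK is a thin wrapper around torch (BK.log_softmax, BK.arange_idx and BK.input_idx behaving like their torch counterparts) and that arc_score and lab_score have shapes [bs, m, h] and [bs, m, h, Lab]; the helper name arc_label_loss is illustrative, not part of the library.

import torch

def arc_label_loss(arc_score, lab_score, gold_heads, gold_labels, mask):
    # arc_score: [bs, m, h], lab_score: [bs, m, h, Lab]
    # gold_heads, gold_labels: [bs, m] long indices; mask: [bs, m] float, position 0 is ROOT
    bs, m = mask.shape
    arc_ls = torch.log_softmax(arc_score, dim=-1)           # [bs, m, h]
    lab_ls = torch.log_softmax(lab_score, dim=-1)           # [bs, m, h, Lab]
    ar_b = torch.arange(bs).unsqueeze(-1)                   # [bs, 1]
    ar_m = torch.arange(m).unsqueeze(0)                     # [1, m]
    arc_sel = arc_ls[ar_b, ar_m, gold_heads]                # gold-head log-probs, [bs, m]
    lab_sel = lab_ls[ar_b, ar_m, gold_heads, gold_labels]   # gold-label log-probs, [bs, m]
    # mask out padding and skip the ROOT column before summing
    arc_sum = (-arc_sel * mask)[:, 1:].sum()
    lab_sum = (-lab_sel * mask)[:, 1:].sum()
    return arc_sum + lab_sum, mask[:, 1:].sum()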
Example #2
 def predict(self, insts: List[ParseInstance], enc_expr, mask_expr,
             **kwargs):
     conf = self.conf
     # scoring
     arc_score, lab_score = self._score(enc_expr,
                                        mask_expr)  # [bs, m, h, *]
     full_score = BK.log_softmax(arc_score, -2) + BK.log_softmax(
         lab_score, -1)  # [bs, m, h, Lab]
     # decode
     mst_lengths = [len(z) + 1
                    for z in insts]  # +1 to include ROOT for mst decoding
     mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
     mst_heads_arr, mst_labels_arr, mst_scores_arr = \
         nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
     # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
     misc_prefix = "g"
     for one_idx, one_inst in enumerate(insts):
         cur_length = mst_lengths[one_idx]
         one_inst.pred_heads.set_vals(
             mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
         one_inst.pred_labels.build_vals(
             mst_labels_arr[one_idx][:cur_length], self.label_vocab)
         one_scores = mst_scores_arr[one_idx][:cur_length]
         one_inst.pred_par_scores.set_vals(one_scores)
         # extra output
         one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
Example #3
 def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
     conf = self.conf
     bsize = len(ms_items)
     # build targets (include all sents)
     # todo(note): use "x.entity_fillers" for getting gold args
     offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
         ms_items, lambda x: x.entity_fillers, True, True,
         conf.train_neg_rate, conf.train_neg_rate_outside, True)
     labels_t.clamp_(max=1)  # either 0 or 1
     # -----
     # return 0 if all no targets
     if BK.get_shape(offsets_t, -1) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     # -----
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     # build loss
     logits = self.predictor(hiddens)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     picked_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(
         -1)  # [bsize, ?]
     masked_losses = picked_log_probs * masks_t
     # loss_sum, loss_count, gold_count
     return [[
         masked_losses.sum(),
         masks_t.sum(), (labels_t > 0).float().sum()
     ]]
Example #4
 def predict(self,
             repr_ef,
             repr_evt,
             lab_ef,
             lab_evt,
             mask_ef=None,
             mask_evt=None,
             ret_full_logprobs=False):
     # -----
     ret_shape = BK.get_shape(lab_ef)[:-1] + [
         BK.get_shape(lab_ef, -1),
         BK.get_shape(lab_evt, -1)
     ]
     if np.prod(ret_shape) == 0:
         if ret_full_logprobs:
             return BK.zeros(ret_shape + [self.num_label])
         else:
             return BK.zeros(ret_shape), BK.zeros(ret_shape).long()
     # -----
     # todo(note): +1 for space of DROPED(UNK)
     full_score = self._score(repr_ef, repr_evt, lab_ef + 1,
                              lab_evt + 1)  # [*, len-ef, len-evt, D]
     full_logprobs = BK.log_softmax(full_score, -1)
     if ret_full_logprobs:
         return full_logprobs
     else:
         # greedy maximum decode
         ret_logprobs, ret_idxes = full_logprobs.max(
             -1)  # [*, len-ef, len-evt]
         # mask non-valid ones
         if mask_ef is not None:
             ret_idxes *= (mask_ef.unsqueeze(-1)).long()
         if mask_evt is not None:
             ret_idxes *= (mask_evt.unsqueeze(-2)).long()
         return ret_logprobs, ret_idxes
Example #5
 def _get_one_loss(self, predictor, hidden_t, labels_t, masks_t,
                   lambda_loss):
     logits = predictor(hidden_t)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     picked_neg_log_probs = -BK.gather_one_lastdim(
         log_probs, labels_t).squeeze(-1)  # [bsize, ?]
     masked_losses = picked_neg_log_probs * masks_t
     # loss_sum, loss_count, gold_count(only for type)
     return [
         masked_losses.sum() * lambda_loss,
         masks_t.sum(), (labels_t > 0).float().sum()
     ]
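
Examples #3 and #5 share the same masked NLL pattern: log-softmax the logits, gather the log-probability of each gold label along the last dimension, negate, and mask. A minimal sketch in plain PyTorch, assuming BK.gather_one_lastdim(t, idx) behaves like t.gather(-1, idx.unsqueeze(-1)); masked_nll is an illustrative name, not a library function.

import torch

def masked_nll(logits, labels, masks, lambda_loss=1.0):
    # logits: [bsize, n, Out]; labels: [bsize, n] long; masks: [bsize, n] float
    log_probs = torch.log_softmax(logits, dim=-1)
    picked_neg = -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [bsize, n]
    masked_losses = picked_neg * masks
    # loss_sum, loss_count, gold_count
    return masked_losses.sum() * lambda_loss, masks.sum(), (labels > 0).float().sum()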
Example #6
 def _pred_and_put_res(self, predictor, hidden_t, evt_arr, put_f):
     logits = predictor(hidden_t)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     max_log_probs, max_label_idxes = log_probs.max(
         -1)  # [bs, ?], simply argmax prediction
     max_log_probs_arr, max_label_idxes_arr = BK.get_value(
         max_log_probs), BK.get_value(max_label_idxes)
     for evt_row, lprob_row, lidx_row in zip(evt_arr, max_log_probs_arr,
                                             max_label_idxes_arr):
         for one_evt, one_lprob, one_lidx in zip(evt_row, lprob_row,
                                                 lidx_row):
             if one_evt is not None:
                 put_f(one_evt, one_lprob,
                       one_lidx)  # callback for inplace setting
Example #7
 def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
     conf = self.conf
     free_mode = (force_widx is None)
     prev_state_h = prev_state[0]
     # =====
     # collect att scores
     key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
     query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
     orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
     orig_scores += (1.-input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
     # first maximum across the R dim (this step is hard max)
     maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
     if conf.zero_eos_score:
         # use mask to make it able to be backward
         tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
         tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
         maxr_scores *= tmp_mask
     # then select over the slen dim (this step is prob based)
     maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
     if free_mode:
         cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
         sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
     else:
         sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
         sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
     # then collect the info and perform labeling
     lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
     lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
     lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
     lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
     lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
     # final predicting labels
     # todo(+N): here we select only max at labeling part, only beam at previous one
     if free_mode:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
     else:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
     # no lab-logprob (*=0) for eos (sel_tok==0)
     sel_lab_logprobs *= (sel_tok_idxes>0).float()
     # compute next-state [*, ?, ~]
     # todo(note): here we flatten the first two dims
     tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
     tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
     tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
     tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                       for z in prev_state]  # [*, ?, ?, D]
     next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
     next_state = [z.view(tmp_rnn_dims) for z in next_state]
     return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
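
The first scoring step in Example #7 uses the standard masking trick of adding a very large negative constant to padded positions before any max or softmax, so they end up with effectively zero probability. A minimal sketch of that step alone, assuming Constants.REAL_PRAC_MIN is a large negative float (the value -1e30 below is only an illustrative assumption):

import torch

REAL_PRAC_MIN = -1e30  # assumed stand-in for Constants.REAL_PRAC_MIN

def mask_then_logsoftmax(orig_scores, input_mask):
    # orig_scores: [*, slen, R]; input_mask: [*, slen] with 1 for real tokens, 0 for padding
    masked = orig_scores + (1. - input_mask).unsqueeze(-1) * REAL_PRAC_MIN
    maxr_scores, maxr_idxes = masked.max(-1)                    # hard max over the R dimension
    return torch.log_softmax(maxr_scores, dim=-1), maxr_idxes   # distribution over slen positions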
Example #8
 def loss(self,
          repr_ef,
          repr_evt,
          lab_ef,
          lab_evt,
          mask_ef,
          mask_evt,
          gold_idxes,
          margin=0.):
     conf = self.conf
     # -----
     if np.prod(BK.get_shape(gold_idxes)) == 0:
         return [[BK.zeros([]), BK.zeros([])]]
     # -----
     # todo(note): +1 for space of DROPED(UNK)
     lab_ef = self._dropout_idxes(lab_ef + 1, conf.train_drop_ef_lab)
     lab_evt = self._dropout_idxes(lab_evt + 1, conf.train_drop_evt_lab)
     if conf.linker_ef_detach:
         repr_ef = repr_ef.detach()
     if conf.linker_evt_detach:
         repr_evt = repr_evt.detach()
     full_score = self._score(repr_ef, repr_evt, lab_ef,
                              lab_evt)  # [*, len-ef, len-evt, D]
     if margin > 0.:
         aug_score = BK.zeros(BK.get_shape(full_score)) + margin
         aug_score.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)
         full_score += aug_score
     full_logprobs = BK.log_softmax(full_score, -1)
     gold_logprobs = full_logprobs.gather(-1,
                                          gold_idxes.unsqueeze(-1)).squeeze(
                                              -1)  # [*, len-ef, len-evt]
     # sampling and mask
     loss_mask = mask_ef.unsqueeze(-1) * mask_evt.unsqueeze(-2)
     # ====
     # first select examples (randomly)
     sel_mask = (BK.rand(BK.get_shape(loss_mask)) <
                 conf.train_min_rate).float()  # [*, len-ef, len-evt]
     # add gold and exclude pad
     sel_mask += (gold_idxes > 0).float()
     sel_mask.clamp_(max=1.)
     loss_mask *= sel_mask
     # =====
     loss_sum = -(gold_logprobs * loss_mask).sum()
     loss_count = loss_mask.sum()
     ret_losses = [[loss_sum, loss_count]]
     return ret_losses
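
The margin branch in Example #8 is a softmax-margin style cost augmentation: every non-gold label has margin added to its score before the log-softmax, while the gold label's score is left unchanged. A small self-contained sketch of just that augmentation in plain PyTorch (add_margin is an illustrative helper name):

import torch

def add_margin(full_score, gold_idxes, margin):
    # full_score: [*, C] float; gold_idxes: [*] long
    aug_score = torch.full_like(full_score, margin)
    aug_score.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)  # zero cost at the gold label
    return full_score + aug_score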
Example #9
 def _predict(self, all_scores, force_idxes):
     # predicting on the last one
     last_score = BK.log_softmax(all_scores[self.eff_max_layer - 1],
                                 -1)  # [*, ?]
     if force_idxes is None:
         # todo(note): currently only do max
         res_logprobs, res_idxes = last_score.max(-1)  # [*]
     else:
         res_idxes = force_idxes
         res_logprobs = last_score.gather(-1,
                                          res_idxes.unsqueeze(-1)).squeeze(
                                              -1)  # [*]
     # lookup: [*, D]
     conf = self.conf
     if conf.use_lookup_soft:
         ret_lab_embeds = self.lookup_soft(all_scores)
     else:
         ret_lab_embeds = self.lookup(res_idxes)
     return res_logprobs, res_idxes, ret_lab_embeds
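
Example #9 follows a free-vs-forced decoding pattern: with no forced indices it takes the argmax, otherwise it gathers the log-probabilities of the given indices (for instance gold labels during training). A minimal sketch of the same branching in plain PyTorch:

import torch

def predict_or_force(scores, force_idxes=None):
    # scores: [*, C]; force_idxes: [*] long, or None for free (argmax) decoding
    log_probs = torch.log_softmax(scores, dim=-1)
    if force_idxes is None:
        res_logprobs, res_idxes = log_probs.max(-1)   # greedy argmax decode
    else:
        res_idxes = force_idxes
        res_logprobs = log_probs.gather(-1, res_idxes.unsqueeze(-1)).squeeze(-1)
    return res_logprobs, res_idxes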
Example #10
 def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
     conf = self.conf
     CR, PR = conf.cand_range, conf.pred_range
     # -----
     mask_single = BK.copy(mask_t)
     # no predictions for ARTI_ROOT
     if self.add_root_token:
         mask_single[:, 0] = 0.  # [bs, slen]
     # casting predicting range
     cur_slen = BK.get_shape(mask_single, -1)
     arange_t = BK.arange_idx(cur_slen)  # [slen]
     # [1, len] - [len, 1] = [len, len]
     reldist_t = arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1)  # [slen, slen]
     mask_pair = ((reldist_t.abs() <= CR) &
                  (reldist_t != 0)).float()  # within CR-range; [slen, slen]
     mask_pair = mask_pair * mask_single.unsqueeze(
         -1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
     if disturb_keep_arr is not None:
         mask_pair *= BK.input_real(1. - disturb_keep_arr).unsqueeze(
             -1)  # no predictions for the kept ones!
     # get all pair scores
     score_t = self.ps_node.paired_score(
         repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
     # -----
     # loss: normalize on which dim?
     # get the answers first
     if conf.pred_abs:
         answer_t = reldist_t.abs()  # [1,2,3,...,PR]
         # [slen, slen], clip in range, distinguish using masks
         answer_t.clamp_(min=0, max=PR - 1)
     else:
         answer_t = BK.where(
             (reldist_t >= 0), reldist_t - 1,
             reldist_t + 2 * PR)  # [1,2,3,...PR,-PR,-PR+1,...,-1]
         # [slen, slen], clip in range, distinguish using masks
         answer_t.clamp_(min=0, max=2 * PR - 1)
     # expand answer into idxes
     answer_hit_t = BK.zeros(BK.get_shape(answer_t) +
                             [2 * PR])  # [len_q, len_k, 2*R]
     answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
     answer_valid_t = ((reldist_t.abs() <= PR) &
                       (reldist_t != 0)).float().unsqueeze(
                           -1)  # [bs, len_q, len_k, 1]
     answer_hit_t = answer_hit_t * mask_pair.unsqueeze(
         -1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
     # get losses sum(log(answer*prob))
     # -- dim=-1 is standard 2*PR classification, dim=-2 usually have 2*PR candidates, but can be less at edges
     all_losses = []
     for one_dim, one_lambda in zip([-1, -2],
                                    [conf.lambda_n1, conf.lambda_n2]):
         if one_lambda > 0.:
             # since currently there can be only one or zero correct answer
             logprob_t = BK.log_softmax(score_t,
                                        one_dim)  # [bs, len_q, len_k, 2*R]
             sumlogprob_t = (logprob_t * answer_hit_t).sum(
                 one_dim)  # [bs, len_q, len_k||2*R]
             cur_dim_mask_t = (answer_hit_t.sum(one_dim) >
                               0.).float()  # [bs, len_q, len_k||2*R]
             # loss
             cur_dim_loss = -(sumlogprob_t * cur_dim_mask_t).sum()
             cur_dim_count = cur_dim_mask_t.sum()
             # argmax and corr (any correct counts)
             _, cur_argmax_idxes = score_t.max(one_dim)
             cur_corrs = answer_hit_t.gather(
                 one_dim, cur_argmax_idxes.unsqueeze(
                     one_dim))  # [bs, len_q, len_k|1, 2*R|1]
             cur_dim_corr_count = cur_corrs.sum()
             # compile loss
             one_loss = LossHelper.compile_leaf_info(
                 f"d{one_dim}",
                 cur_dim_loss,
                 cur_dim_count,
                 loss_lambda=one_lambda,
                 corr=cur_dim_corr_count)
             all_losses.append(one_loss)
     return self._compile_component_loss("orp", all_losses)
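
The loop at the end of Example #10 normalizes the same pair scores in two directions: dim=-1 treats each (query, key) pair as a 2*PR-way offset classification, while dim=-2 treats each (query, offset) pair as a choice among key positions. A minimal sketch of that dual normalization, assuming score_t and answer_hit_t both have shape [bs, len_q, len_k, 2*PR] and answer_hit_t contains at most one 1 per normalized slice:

import torch

def dual_normalized_losses(score_t, answer_hit_t, dims=(-1, -2)):
    losses = []
    for one_dim in dims:
        logprob_t = torch.log_softmax(score_t, dim=one_dim)
        sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)   # gold log-prob per slice
        cur_mask_t = (answer_hit_t.sum(one_dim) > 0.).float()    # slices that actually have an answer
        losses.append((-(sumlogprob_t * cur_mask_t).sum(), cur_mask_t.sum()))
    return losses  # [(loss_sum, loss_count), ...], one pair per normalization direction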
Example #11
 def fb_on_batch(self,
                 annotated_insts,
                 training=True,
                 loss_factor=1,
                 **kwargs):
     self.refresh_batch(training)
     margin = self.margin.value
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     gold_labels_arr, _ = self.predict_padder.pad(
         [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     # ===== calculate
     scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
         annotated_insts, training)
     full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                           training, margin,
                                           gold_heads_expr)
     #
     final_losses = None
     if self.norm_local or self.norm_single:
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         # already added margin previously
         losses_heads = losses_labels = None
         if self.loss_prob:
             if self.norm_local:
                 losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                 losses_labels = BK.loss_nll(select_label_score,
                                             gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=False)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=False)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_hinge:
             if self.norm_local:
                 losses_heads = BK.loss_hinge(full_arc_score,
                                              gold_heads_expr)
                 losses_labels = BK.loss_hinge(select_label_score,
                                               gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=True,
                                                    margin=margin)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=True,
                                                     margin=margin)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_mr:
             # special treatment!
             probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
             probs_labels = BK.softmax(select_label_score,
                                       dim=-1)  # [bs, m, h]
             # select
             probs_head_gold = BK.gather_one_lastdim(
                 probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
             probs_label_gold = BK.gather_one_lastdim(
                 probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
             # root and pad will be excluded later
             # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
             # todo(warn): have problem since steps will be quite small, not used!
             # let loss>=0
             final_losses = mask_expr - probs_head_gold * probs_label_gold
     elif self.norm_global:
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, training,
                                                   margin, gold_heads_expr,
                                                   gold_labels_expr)
         # for this one, use the merged full score
         full_score = full_arc_score.unsqueeze(
             -1) + full_label_score  # [BS, m, h, L]
         # +=1 to include ROOT for mst decoding
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         # do inference
         if self.loss_prob:
             marginals_expr = self._marginal(
                 full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
             final_losses = self._losses_global_prob(
                 full_score, gold_heads_expr, gold_labels_expr,
                 marginals_expr, mask_expr)
             if self.alg_proj:
                 # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                 #  but this might be too loose, although the unproj edges are few?
                 gold_unproj_arr, _ = self.predict_padder.pad(
                     [z.unprojs for z in annotated_insts])
                 gold_unproj_expr = BK.input_real(
                     gold_unproj_arr)  # [BS, Len]
                 comparing_expr = Constants.REAL_PRAC_MIN * (
                     1. - gold_unproj_expr)
                 final_losses = BK.max_elem(final_losses, comparing_expr)
         elif self.loss_hinge:
             pred_heads_arr, pred_labels_arr, _ = self._decode(
                 full_score, mask_expr, mst_lengths_arr)
             pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
             pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
             #
             final_losses = self._losses_global_hinge(
                 full_score, gold_heads_expr, gold_labels_expr,
                 pred_heads_expr, pred_labels_expr, mask_expr)
         elif self.loss_mr:
             # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
             raise NotImplementedError(
                 "Not implemented for global-loss + mr.")
     elif self.norm_hlocal:
         # firstly label losses are the same
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
         # then specially for arc loss
         children_masks_arr, _ = self.hlocal_padder.pad(
             [z.get_children_mask_arr() for z in annotated_insts])
         children_masks_expr = BK.input_real(
             children_masks_arr)  # [bs, h, m]
         # [bs, h]
         # todo(warn): use prod rather than sum, but still only an approximation for the top-down
         # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
         losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(
             -1, -2) * children_masks_expr,
                              dim=-1)
         # including the root-head is important
         losses_arc[:, 1] += losses_arc[:, 0]
         final_losses = losses_arc + losses_labels
     #
     # jpos loss? (the same mask as parsing)
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         final_losses += jpos_losses_expr
     # collect loss with mask, also excluding the first symbol of ROOT
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     #
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     if training:
         BK.backward(final_loss, loss_factor)
     return info
Example #12
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # ===== calculate
         scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
             insts, False)
         full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                               False, 0.)
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, False, 0.)
         # normalizing scores
         full_score = None
         final_exp_score = False  # whether to provide PROB by exp
         if self.norm_local and self.loss_prob:
             full_score = BK.log_softmax(full_arc_score,
                                         -1).unsqueeze(-1) + BK.log_softmax(
                                             full_label_score, -1)
             final_exp_score = True
         elif self.norm_hlocal and self.loss_prob:
             # normalize at m dimension, ignore each nodes's self-finish step.
             full_score = BK.log_softmax(full_arc_score,
                                         -2).unsqueeze(-1) + BK.log_softmax(
                                             full_label_score, -1)
         elif self.norm_single and self.loss_prob:
             if self.conf.iconf.dec_single_neg:
                 # todo(+2): add all-neg for prob explanation
                 full_arc_probs = BK.sigmoid(full_arc_score)
                 full_label_probs = BK.sigmoid(full_label_score)
                 fake_arc_scores = BK.log(full_arc_probs) - BK.log(
                     1. - full_arc_probs)
                 fake_label_scores = BK.log(full_label_probs) - BK.log(
                     1. - full_label_probs)
                 full_score = fake_arc_scores.unsqueeze(
                     -1) + fake_label_scores
             else:
                 full_score = BK.logsigmoid(full_arc_score).unsqueeze(
                     -1) + BK.logsigmoid(full_label_score)
                 final_exp_score = True
         else:
             full_score = full_arc_score.unsqueeze(-1) + full_label_score
         # decode
         # +1 to include ROOT for mst decoding
         mst_lengths = [len(z) + 1 for z in insts]
         mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(
             full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
         if final_exp_score:
             mst_scores_arr = np.exp(mst_scores_arr)
         # jpos prediction (directly index, no converting as in parsing)
         jpos_preds_expr = jpos_pack[2]
         has_jpos_pred = jpos_preds_expr is not None
         jpos_preds_arr = BK.get_value(
             jpos_preds_expr) if has_jpos_pred else None
         # ===== assign
         info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
         mst_real_labels = self.pred2real_labels(mst_labels_arr)
         for one_idx, one_inst in enumerate(insts):
             cur_length = mst_lengths[one_idx]
             one_inst.pred_heads.set_vals(
                 mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(
                 mst_real_labels[one_idx][:cur_length], self.label_vocab)
             one_inst.pred_par_scores.set_vals(
                 mst_scores_arr[one_idx][:cur_length])
             if has_jpos_pred:
                 one_inst.pred_poses.build_vals(
                     jpos_preds_arr[one_idx][:cur_length],
                     self.bter.pos_vocab)
         return info
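
In the norm_local branch of Example #12 (and similarly in Example #2), the arc scores are normalized over the head dimension and the label scores over the label dimension before being summed; exponentiating the decoded scores then yields probabilities. A minimal sketch of that combination in plain PyTorch, assuming full_arc_score is [bs, m, h] and full_label_score is [bs, m, h, L]:

import torch

def local_norm_full_score(full_arc_score, full_label_score):
    # full_arc_score: [bs, m, h]; full_label_score: [bs, m, h, L]
    arc_part = torch.log_softmax(full_arc_score, dim=-1).unsqueeze(-1)   # log P(h | m), [bs, m, h, 1]
    lab_part = torch.log_softmax(full_label_score, dim=-1)               # log P(lab | m, h)
    return arc_part + lab_part  # [bs, m, h, L], handed to the MST decoder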
Example #13
 def predict(self, ms_items: List, bert_expr, basic_expr):
     conf = self.conf
     bsize = len(ms_items)
     # build targets (include all sents)
     offsets_t, masks_t, _, _, _ = PrepHelper.prep_targets(
         ms_items, lambda x: [], True, True, 1., 1., False)
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     logits = self.predictor(hiddens)  # [bsize, ?, Out]
     # -----
     log_probs = BK.log_softmax(logits, -1)
     log_probs[:, :, 0] -= conf.nil_penalty  # encourage more predictions
     topk_log_probs, topk_log_labels = log_probs.max(dim=-1)  # [bsize, ?]
     # decoding
     head_offsets_arr = BK.get_value(offsets_t)  # [bs, ?]
     masks_arr = BK.get_value(masks_t)
     topk_log_probs_arr, topk_log_labels_arr = BK.get_value(
         topk_log_probs), BK.get_value(topk_log_labels)  # [bsize, ?]
     for one_ms_item, one_offsets_arr, one_masks_arr, one_logprobs_arr, one_labels_arr \
             in zip(ms_items, head_offsets_arr, masks_arr, topk_log_probs_arr, topk_log_labels_arr):
         # build tidx2sidx
         one_sents = one_ms_item.sents
         one_offsets = one_ms_item.offsets
         tidx2sidx = []
         for idx in range(1, len(one_offsets)):
             tidx2sidx.extend([idx - 1] *
                              (one_offsets[idx] - one_offsets[idx - 1]))
         # get all candidates
         all_candidates = [[] for _ in one_sents]
         for cur_offset, cur_valid, cur_logprob, cur_label in zip(
                 one_offsets_arr, one_masks_arr, one_logprobs_arr,
                 one_labels_arr):
             if not cur_valid or cur_label <= 0:
                 continue
             # which sent
             cur_offset = int(cur_offset)
             cur_sidx = tidx2sidx[cur_offset]
             cur_sent = one_sents[cur_sidx]
             # again consider the ROOT
             minus_offset = one_ms_item.offsets[cur_sidx] - 1
             cur_mention = Mention(
                 HardSpan(cur_sent.sid, cur_offset - minus_offset, None,
                          None))
             all_candidates[cur_sidx].append(
                 (cur_sent, cur_mention, cur_label, cur_logprob))
         # keep certain ratio for each sent separately?
         final_candidates = []
         if conf.pred_sent_ratio_sep:
             for one_sent, one_sent_candidates in zip(
                     one_sents, all_candidates):
                 cur_keep_num = max(
                     int(conf.pred_sent_ratio * (one_sent.length - 1)), 1)
                 one_sent_candidates.sort(key=lambda x: x[-1], reverse=True)
                 final_candidates.extend(one_sent_candidates[:cur_keep_num])
         else:
             all_size = 0
             for one_sent, one_sent_candidates in zip(
                     one_sents, all_candidates):
                 all_size += one_sent.length - 1
                 final_candidates.extend(one_sent_candidates)
             final_candidates.sort(key=lambda x: x[-1], reverse=True)
             final_keep_num = max(int(conf.pred_sent_ratio * all_size),
                                  len(one_sents))
             final_candidates = final_candidates[:final_keep_num]
         # add them all
         for cur_sent, cur_mention, cur_label, cur_logprob in final_candidates:
             cur_logprob = float(cur_logprob)
             doc_id = cur_sent.doc.doc_id
             self.id_counter[doc_id] += 1
             new_id = f"ef-{doc_id}-{self.id_counter[doc_id]}"
             hlidx = self.valid_hlidx
             new_ef = EntityFiller(new_id,
                                   cur_mention,
                                   str(hlidx),
                                   None,
                                   True,
                                   type_idx=hlidx,
                                   score=cur_logprob)
             cur_sent.pred_entity_fillers.append(new_ef)
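
Example #13 nudges the decoder away from the NIL class by subtracting a constant penalty from its log-probability before taking the argmax. A minimal sketch of just that step, assuming index 0 is the NIL label and nil_penalty is a non-negative config value:

import torch

def penalized_argmax(logits, nil_penalty):
    # logits: [bsize, n, Out]; prediction time, no gradients needed
    log_probs = torch.log_softmax(logits, dim=-1)
    log_probs[:, :, 0] -= nil_penalty    # discourage the NIL class, encourage real labels
    return log_probs.max(dim=-1)         # (best log-prob, best label index) per position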
Example #14
 def loss(self,
          repr_ts,
          input_erase_mask_arr,
          orig_map: Dict,
          active_hid=True,
          **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # prepare idxes for the masked ones
     if self.add_root_token:  # offset for the special root added in embedder
         mask_idxes, mask_valids = BK.mask2idx(
             BK.input_real(input_erase_mask_arr),
             padding_idx=-1)  # [bsize, ?]
         repr_mask_idxes = mask_idxes + 1
         mask_idxes.clamp_(min=0)
     else:
         mask_idxes, mask_valids = BK.mask2idx(
             BK.input_real(input_erase_mask_arr))  # [bsize, ?]
         repr_mask_idxes = mask_idxes
     # get the losses
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         return self._compile_component_loss("mlm", [])
     else:
         if not isinstance(repr_ts, (List, Tuple)):
             repr_ts = [repr_ts]
         target_word_scores, target_pos_scores = [], []
         target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
         for layer_idx in conf.loss_layers:
             # calculate scores
             target_reprs = BK.gather_first_dims(repr_ts[layer_idx],
                                                 repr_mask_idxes,
                                                 1)  # [bsize, ?, *]
             if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                 target_hids = self.hid_layer(target_reprs)
             else:
                 target_hids = target_reprs
             if _tie_input_embeddings:
                 pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                 target_word_scores.append(BK.matmul(
                     target_hids, pred_W.T))  # List[bsize, ?, Vw]
             else:
                 target_word_scores.append(self.pred_word_layer(
                     target_hids))  # List[bsize, ?, Vw]
         # gather the losses
         all_losses = []
         for pred_name, target_scores, loss_lambda, range_min, range_max in \
                 zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                     [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
             if loss_lambda > 0.:
                 seq_idx_t = BK.input_idx(
                     orig_map[pred_name])  # [bsize, slen]
                 target_idx_t = seq_idx_t.gather(-1,
                                                 mask_idxes)  # [bsize, ?]
                 ranged_mask_valids = mask_valids * (
                     target_idx_t >= range_min).float() * (
                         target_idx_t <= range_max).float()
                 target_idx_t[ranged_mask_valids < 1.] = 0  # make sure invalid ones in range
                 # calculate for each layer
                 all_layer_losses, all_layer_scores = [], []
                 for one_layer_idx, one_target_scores in enumerate(
                         target_scores):
                     # get loss: [bsize, ?]
                     one_pred_losses = BK.loss_nll(
                         one_target_scores,
                         target_idx_t) * conf.loss_weights[one_layer_idx]
                     all_layer_losses.append(one_pred_losses)
                     # get scores
                     one_pred_scores = BK.log_softmax(
                         one_target_scores,
                         -1) * conf.loss_weights[one_layer_idx]
                     all_layer_scores.append(one_pred_scores)
                 # combine all layers
                 pred_losses = self.loss_comb_f(all_layer_losses)
                 pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                 pred_loss_count = ranged_mask_valids.sum()
                 # argmax
                 _, argmax_idxes = self.score_comb_f(all_layer_scores).max(
                     -1)
                 pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                 pred_corr_count = pred_corrs.sum()
                 # compile leaf loss
                 r_loss = LossHelper.compile_leaf_info(
                     pred_name,
                     pred_loss_sum,
                     pred_loss_count,
                     loss_lambda=loss_lambda,
                     corr=pred_corr_count)
                 all_losses.append(r_loss)
         return self._compile_component_loss("mlm", all_losses)
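
When tie_input_embeddings is on, Example #14 scores the masked positions directly against a slice of the input embedding matrix instead of a separate output layer. A minimal sketch of that branch with illustrative names (embed_matrix stands in for self.inputter_word_node.E.E):

import torch

def tied_word_scores(target_hids, embed_matrix, pred_word_size):
    # target_hids: [bsize, n, D]; embed_matrix: [V, D]
    pred_W = embed_matrix[:pred_word_size]       # [PSize, D]
    return torch.matmul(target_hids, pred_W.T)   # [bsize, n, PSize]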