Example #1
 def _get_rel_dist(self, len_q: int, len_k: int = None):
     if len_k is None:
         len_k = len_q
     dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
     dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
     distance = dist_x - dist_y  # [len_q, len_k]
     return distance
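A minimal stand-alone sketch of the same arange-broadcast idiom in plain PyTorch (an assumption here: BK.arange_idx behaves like torch.arange on the active device):

 import torch

 def get_rel_dist(len_q: int, len_k: int = None):
     if len_k is None:
         len_k = len_q
     dist_x = torch.arange(len_k).unsqueeze(0)  # [1, len_k]
     dist_y = torch.arange(len_q).unsqueeze(1)  # [len_q, 1]
     return dist_x - dist_y  # broadcasts to [len_q, len_k]

 # get_rel_dist(3) ->
 # tensor([[ 0,  1,  2],
 #         [-1,  0,  1],
 #         [-2, -1,  0]])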
Example #2
 def loss(self, ms_items: List, bert_expr):
     conf = self.conf
     max_range = self.conf.max_range
     bsize = len(ms_items)
     # collect instances
     col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(
         ms_items, True)
     if len(col_efs) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
     left_scores, right_scores = self._score(bert_expr, col_bidxes_t,
                                             col_hidxes_t)  # [N, R]
     if conf.use_binary_scorer:
         left_binaries, right_binaries = (BK.arange_idx(max_range)<=col_ldists_t.unsqueeze(-1)).float(), \
                                         (BK.arange_idx(max_range)<=col_rdists_t.unsqueeze(-1)).float()  # [N,R]
         left_losses = BK.binary_cross_entropy_with_logits(
             left_scores, left_binaries, reduction='none')[:, 1:]
         right_losses = BK.binary_cross_entropy_with_logits(
             right_scores, right_binaries, reduction='none')[:, 1:]
         left_count = right_count = BK.input_real(
             BK.get_shape(left_losses, 0) * (max_range - 1))
     else:
         left_losses = BK.loss_nll(left_scores, col_ldists_t)
         right_losses = BK.loss_nll(right_scores, col_rdists_t)
         left_count = right_count = BK.input_real(
             BK.get_shape(left_losses, 0))
     return [[left_losses.sum(), left_count, left_count],
             [right_losses.sum(), right_count, right_count]]
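The binary targets above turn a gold span distance into per-offset labels; a small sketch of just that step in plain PyTorch (col_ldists and max_range are hypothetical stand-ins, assuming BK.arange_idx wraps torch.arange):

 import torch

 max_range = 5
 col_ldists = torch.tensor([0, 2, 4])  # gold left distances, [N]
 left_binaries = (torch.arange(max_range) <= col_ldists.unsqueeze(-1)).float()  # [N, R]
 # -> [[1,0,0,0,0],
 #     [1,1,1,0,0],
 #     [1,1,1,1,1]]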
Example #3
 def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
     conf = self.conf
     # scoring
     arc_score, lab_score = self._score(enc_expr,
                                        mask_expr)  # [bs, m, h, *]
     # loss
     bsize, max_len = BK.get_shape(mask_expr)
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in insts])
     # todo(note): here use the original idx of label, no shift!
     gold_labels_arr, _ = self.predict_padder.pad(
         [z.labels.idxes for z in insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
     # collect the losses
     arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
     arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
     # logsoftmax and losses
     arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1),
                                      -1)  # [bs, m, h]
     lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
     arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                  gold_heads_expr]  # [bs, Len]
     lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                  gold_heads_expr,
                                  gold_labels_expr]  # [bs, Len]
     # head selection (no root)
     arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
     lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
     final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
     final_loss_count = mask_expr[:, 1:].sum()
     return [[final_loss, final_loss_count]]
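The gold arc log-probs are gathered with two broadcast arange indexers; a self-contained sketch of that selection step alone (plain PyTorch, hypothetical sizes; assumes BK's fancy indexing matches torch's):

 import torch

 bsize, max_len = 2, 4
 arc_logsoftmax = torch.randn(bsize, max_len, max_len).log_softmax(-1)  # [bs, m, h]
 gold_heads = torch.randint(0, max_len, (bsize, max_len))  # [bs, m]

 arange_bs = torch.arange(bsize).unsqueeze(-1)  # [bs, 1]
 arange_m = torch.arange(max_len).unsqueeze(0)  # [1, m]
 arc_sel_ls = arc_logsoftmax[arange_bs, arange_m, gold_heads]  # [bs, m]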
Example #4
 def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
     if self.dmxnn:
         bsize, slen = BK.get_shape(input_mask)
         if sel_idxes is None:
             sel_idxes = BK.arange_idx(slen).unsqueeze(
                 0)  # select all, [1, slen]
         ncand = BK.get_shape(sel_idxes, -1)
         # enc_expr aug with PE
         rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(
             0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
         pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
         aug_enc_expr = BK.concat([
             pe_embeds.expand(bsize, -1, -1, -1),
             input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)
         ], -1)  # [*, ?, slen, D+Dpe]
         # [*, ?, slen, Denc]
         hidden_expr = self.e_encoder(
             aug_enc_expr.view(bsize * ncand, slen, -1),
             input_mask.unsqueeze(1).expand(-1, ncand,
                                            -1).contiguous().view(
                                                bsize * ncand, slen))
         hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
         # dynamic max-pooling (dist<0, dist=0, dist>0)
         NEG = Constants.REAL_PRAC_MIN
         mp_hiddens = []
         mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
         for mp_mask in mp_masks:
             float_mask = mp_mask.float() * input_mask.unsqueeze(
                 -2)  # [*, ?, slen]
             valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(
                 -1)  # [*, ?, 1]
             mask_neg_val = (
                 1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
             # todo(+2): or do we simply multiply mask?
             mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
             mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
             mp_hiddens.append(self.special_drop(mp_hid))
             # mp_hiddens.append(mp_hid)
         final_hiddens = mp_hiddens
     else:
         hidden_expr = self.e_encoder(input_expr,
                                      input_mask)  # [*, slen, D']
         if sel_idxes is None:
             hidden_expr1 = hidden_expr
         else:
             hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes,
                                                 -2)  # [*, ?, D']
         final_hiddens = [self.special_drop(hidden_expr1)]
     if self.lab_f_use_lexi:
         final_hiddens.append(
             BK.gather_first_dims(input_lexi, sel_idxes,
                                  -2))  # [*, ?, DLex]
     ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
     return ret_expr
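A hedged sketch of just the dynamic max-pooling step (dist<0, dist=0, dist>0), with made-up shapes for illustration; NEG stands in for Constants.REAL_PRAC_MIN:

 import torch

 NEG = -1e30
 bsize, ncand, slen, dim = 2, 3, 7, 8
 hidden = torch.randn(bsize, ncand, slen, dim)
 seq_mask = torch.ones(bsize, slen)
 sel_idxes = torch.tensor([[1, 3, 5], [0, 2, 6]])  # [bsize, ncand]
 rel_dist = torch.arange(slen).view(1, 1, -1) - sel_idxes.unsqueeze(-1)  # [bsize, ncand, slen]

 pooled = []
 for region_mask in (rel_dist < 0, rel_dist == 0, rel_dist > 0):
     float_mask = region_mask.float() * seq_mask.unsqueeze(-2)  # [bsize, ncand, slen]
     valid = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [bsize, ncand, 1]
     neg_pad = (1. - float_mask).unsqueeze(-1) * NEG  # [bsize, ncand, slen, 1]
     pooled.append((hidden + neg_pad).max(-2)[0] * valid)  # [bsize, ncand, dim]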
Example #5
 def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
     conf = self.conf
     bsize = len(ms_items)
     # build targets (include all sents)
     # todo(note): use "x.entity_fillers" for getting gold args
     offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
         ms_items, lambda x: x.entity_fillers, True, True,
         conf.train_neg_rate, conf.train_neg_rate_outside, True)
     labels_t.clamp_(max=1)  # either 0 or 1
     # -----
     # return 0 if all no targets
     if BK.get_shape(offsets_t, -1) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     # -----
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     # build loss
     logits = self.predictor(hiddens)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     picked_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(
         -1)  # [bsize, ?]
     masked_losses = picked_log_probs * masks_t
     # loss_sum, loss_count, gold_count
     return [[
         masked_losses.sum(),
         masks_t.sum(), (labels_t > 0).float().sum()
     ]]
Example #6
 def __call__(self,
              word_arr: np.ndarray = None,
              char_arr: np.ndarray = None,
              extra_arrs: Iterable[np.ndarray] = (),
              aux_arrs: Iterable[np.ndarray] = ()):
     exprs = []
     # word/char/extras/posi
     seq_shape = None
     if self.has_word:
         # todo(warn): singleton-UNK-dropout should be done outside before
         seq_shape = word_arr.shape
         word_expr = self.dropmd_word(self.word_embed(word_arr))
         exprs.append(word_expr)
     if self.has_char:
         seq_shape = char_arr.shape[:-1]
         char_embeds = self.char_embed(
             char_arr)  # [*, seq-len, word-len, D]
         char_cat_expr = self.dropmd_char(
             BK.concat([z(char_embeds) for z in self.char_cnns]))
         exprs.append(char_cat_expr)
     zcheck(
         len(extra_arrs) == len(self.extra_embeds),
         "Unmatched extra fields.")
     for one_extra_arr, one_extra_embed, one_extra_dropmd in zip(
             extra_arrs, self.extra_embeds, self.dropmd_extras):
         seq_shape = one_extra_arr.shape
         exprs.append(one_extra_dropmd(one_extra_embed(one_extra_arr)))
     if self.has_posi:
         seq_len = seq_shape[-1]
         posi_idxes = BK.arange_idx(seq_len)
         posi_input0 = self.posi_embed(posi_idxes)
         for _ in range(len(seq_shape) - 1):
             posi_input0 = BK.unsqueeze(posi_input0, 0)
         posi_input1 = BK.expand(posi_input0, tuple(seq_shape) + (-1, ))
         exprs.append(posi_input1)
     #
     assert len(aux_arrs) == len(self.drop_auxes)
     for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \
             zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes, self.aux_overall_gammas, self.aux_fold_lambdas):
         # fold and apply trainable lambdas
         input_aux_repr = BK.input_real(one_aux_arr)
         input_shape = BK.get_shape(input_aux_repr)
         # todo(note): assume the original concat is [fold/layer, D]
         reshaped_aux_repr = input_aux_repr.view(
             input_shape[:-1] +
             [one_fold, one_aux_dim])  # [*, slen, fold, D]
         lambdas_softmax = BK.softmax(one_lambdas, -1).unsqueeze(-1)  # [fold, 1]
         weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax
                              ).sum(-2) * one_gamma  # [*, slen, D]
         one_aux_expr = one_aux_drop(weighted_aux_repr)
         exprs.append(one_aux_expr)
     #
     concated_exprs = BK.concat(exprs, dim=-1)
     # optional proj
     if self.has_proj:
         final_expr = self.final_layer(concated_exprs)
     else:
         final_expr = concated_exprs
     return final_expr
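The per-fold weighting is an ELMo-style scalar mix; a minimal sketch of that part alone, assuming one_lambdas are trainable per-fold weights and one_gamma an overall scale (hypothetical names and shapes, plain PyTorch):

 import torch

 bsize, slen, fold, dim = 2, 5, 3, 4
 aux_arr = torch.randn(bsize, slen, fold * dim)  # original concat is [fold, D]
 fold_lambdas = torch.zeros(fold)  # trainable in the real module
 gamma = torch.tensor(1.)  # trainable overall scale

 reshaped = aux_arr.view(bsize, slen, fold, dim)  # [*, slen, fold, D]
 weights = fold_lambdas.softmax(-1).unsqueeze(-1)  # [fold, 1]
 weighted = (reshaped * weights).sum(-2) * gamma  # [*, slen, D]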
Example #7
 def loss(self, ms_items: List, bert_expr, basic_expr):
     conf = self.conf
     bsize = len(ms_items)
     # use gold targets: only use positive samples!!
     offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
         ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
     realis_flist = [(-1 if
                      (z is None or z.realis_idx is None) else z.realis_idx)
                     for z in items_arr.flatten()]
     realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
     realis_mask = (realis_t >= 0).float()
     realis_t.clamp_(min=0)  # make sure all idxes are legal
     # -----
     # return 0 if all no targets
     if BK.get_shape(offsets_t, -1) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
     # -----
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     # build losses
     loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens,
                                           realis_t, realis_mask,
                                           conf.lambda_realis)
     loss_item_type = self._get_one_loss(self.type_predictor, hiddens,
                                         labels_t, masks_t,
                                         conf.lambda_type)
     return [loss_item_realis, loss_item_type]
Example #8
 def _decode(self, insts: List[ParseInstance], full_score, mask_expr,
             misc_prefix):
     # decode
     mst_lengths = [len(z) + 1
                    for z in insts]  # +=1 to include ROOT for mst decoding
     mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
     mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(
         full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
     if self.conf.iconf.output_marginals:
         # todo(note): here, we care about marginals for arc
         # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
         arc_marginals = nmarginal_unproj(full_score,
                                          mask_expr,
                                          None,
                                          labeled=True).sum(-1)
         bsize, max_len = BK.get_shape(mask_expr)
         idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
         idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
         output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr,
                                     BK.input_idx(mst_heads_arr)]
         mst_marg_arr = BK.get_value(output_marg)
     else:
         mst_marg_arr = None
     # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
     for one_idx, one_inst in enumerate(insts):
         cur_length = mst_lengths[one_idx]
         one_inst.pred_heads.set_vals(
             mst_heads_arr[one_idx]
             [:cur_length])  # directly int-val for heads
         one_inst.pred_labels.build_vals(
             mst_labels_arr[one_idx][:cur_length], self.label_vocab)
         one_scores = mst_scores_arr[one_idx][:cur_length]
         one_inst.pred_par_scores.set_vals(one_scores)
         # extra output
         one_inst.extra_pred_misc[misc_prefix +
                                  "_score"] = one_scores.tolist()
         if mst_marg_arr is not None:
             one_inst.extra_pred_misc[
                 misc_prefix +
                 "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()
Example #9
 def _score(self, bert_expr, bidxes_t, hidxes_t):
     # ----
     # # debug
     # print(f"# ====\n Debug: {ArgSpanExpander._debug_count}")
     # ArgSpanExpander._debug_count += 1
     # ----
     bert_expr = bert_expr.view(BK.get_shape(bert_expr)[:-2] +
                                [-1])  # flatten
     #
     max_range = self.conf.max_range
     max_slen = BK.get_shape(bert_expr, 1)
     # get candidates
     range_t = BK.arange_idx(max_range).unsqueeze(0)  # [1, R]
     bidxes_t = bidxes_t.unsqueeze(1)  # [N, 1]
     hidxes_t = hidxes_t.unsqueeze(1)  # [N, 1]
     left_cands = hidxes_t - range_t  # [N, R]
     right_cands = hidxes_t + range_t
     left_masks = (left_cands >= 0).float()
     right_masks = (right_cands < max_slen).float()
     left_cands.clamp_(min=0)
     right_cands.clamp_(max=max_slen - 1)
     # score
     head_exprs = bert_expr[bidxes_t, hidxes_t]  # [N, 1, D']
     left_cand_exprs = bert_expr[bidxes_t, left_cands]  # [N, R, D']
     right_cand_exprs = bert_expr[bidxes_t, right_cands]
     # actual scoring
     if self.use_lstm_scorer:
         batch_size = BK.get_shape(bidxes_t, 0)
         all_concat_outputs = []
         for cand_exprs, lstm_node in zip(
             [left_cand_exprs, right_cand_exprs], [self.llstm, self.rlstm]):
             cur_state = lstm_node.zero_init_hidden(batch_size)
             step_size = BK.get_shape(cand_exprs, 1)
             all_outputs = []
             for step_i in range(step_size):
                 cur_state = lstm_node(cand_exprs[:, step_i], cur_state,
                                       None)
                 all_outputs.append(cur_state[0])  # using h
             concat_output = BK.stack(all_outputs, 1)  # [N, R, ?]
             all_concat_outputs.append(concat_output)
         left_hidden, right_hidden = all_concat_outputs
         left_scores = self.lscorer(left_hidden).squeeze(-1)  # [N, R]
         right_scores = self.rscorer(right_hidden).squeeze(-1)  # [N, R]
     else:
         left_scores = self.lscorer([left_cand_exprs,
                                     head_exprs]).squeeze(-1)  # [N, R]
         right_scores = self.rscorer([right_cand_exprs,
                                      head_exprs]).squeeze(-1)
     # mask
     left_scores += Constants.REAL_PRAC_MIN * (1. - left_masks)
     right_scores += Constants.REAL_PRAC_MIN * (1. - right_masks)
     return left_scores, right_scores
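A stand-alone sketch of the candidate-window construction used above: offsets 0..R-1 to the left and right of each head, with out-of-range positions clamped and remembered in masks (plain PyTorch, hypothetical sizes):

 import torch

 max_range, max_slen = 4, 10
 hidxes_t = torch.tensor([0, 3, 9])  # head positions, [N]
 range_t = torch.arange(max_range).unsqueeze(0)  # [1, R]
 left_cands = hidxes_t.unsqueeze(1) - range_t  # [N, R]
 right_cands = hidxes_t.unsqueeze(1) + range_t  # [N, R]
 left_masks = (left_cands >= 0).float()
 right_masks = (right_cands < max_slen).float()
 left_cands.clamp_(min=0)
 right_cands.clamp_(max=max_slen - 1)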
Example #10
 def predict(self, insts: List, input_lexi, input_expr, input_mask):
     input_mask[:, 0] = 0.  # no artificial root
     final_score, attn, attn2 = self._score(input_expr, input_mask)
     pred_mask = self._predict(final_score, attn, attn2,
                               input_mask)  # [*, slen, L]
     sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = self._pmask2idxes(
         pred_mask)
     all_logprobs = final_score.log().unsqueeze(-2) + (
         attn + 1e-10).log()  # [*, slen, L]
     bsize = len(insts)
     sel_lab_logprobs = all_logprobs[BK.arange_idx(bsize).unsqueeze(-1),
                                     sel_idxes, sel_lab_idxes]  # [*, ?]
     return sel_lab_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Example #11
 def __call__(self, input_v, add_root_token: bool):
     if isinstance(input_v, np.ndarray):
         # direct use this [batch_size, slen] as input
         posi_idxes = BK.input_idx(input_v)
         expr = self.node(posi_idxes)  # [batch_size, slen, D]
     else:
         # input is a shape as prepared by "PosiHelper"
         batch_size, max_len = input_v
         if add_root_token:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [1?+slen] add root=0 here
         expr = self.node(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     return self.dropout(expr)
Example #12
 def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2,
                          hit_labels, query_idxes0, query_idxes1,
                          query_idxes2, arc_scores, lab_scores,
                          arc_margin: float, lab_margin: float):
     # arc
     gold_arc_mat = BK.constants(shape, 0.)
     gold_arc_mat[hit_idxes0, hit_idxes1, hit_idxes2] = arc_margin
     gold_arc_margins = gold_arc_mat[query_idxes0, query_idxes1,
                                     query_idxes2]
     arc_scores -= gold_arc_margins
     if lab_scores is not None:
         # label
         gold_lab_mat = BK.constants_idx(shape,
                                         0)  # 0 means the padding idx
         gold_lab_mat[hit_idxes0, hit_idxes1, hit_idxes2] = hit_labels
         gold_lab_margin_idxes = gold_lab_mat[query_idxes0, query_idxes1,
                                              query_idxes2]
         lab_scores[BK.arange_idx(BK.get_shape(gold_lab_margin_idxes, 0)),
                    gold_lab_margin_idxes] -= lab_margin
     return
Example #13
 def postprocess_scores(self, scores_expr, mask_expr, margin,
                        gold_heads_expr, gold_labels_expr):
     final_full_scores = scores_expr
     # first apply mask
     mask_value = Constants.REAL_PRAC_MIN
     mask_mul = (mask_value * (1. - mask_expr)).unsqueeze(-1)  # [*, len, 1]
     final_full_scores += mask_mul.unsqueeze(-2)
     final_full_scores += mask_mul.unsqueeze(-3)
     # then margin
     if margin > 0.:
         full_shape = BK.get_shape(final_full_scores)
         # combine the first two dim, and minus margin correspondingly
         combined_size = full_shape[0] * full_shape[1]
         combined_score_expr = final_full_scores.view([combined_size] +
                                                      full_shape[-2:])
         arange_idx_expr = BK.arange_idx(combined_size)
         combined_score_expr[arange_idx_expr,
                             gold_heads_expr.view(-1)] -= margin
         combined_score_expr[arange_idx_expr,
                             gold_heads_expr.view(-1),
                             gold_labels_expr.view(-1)] -= margin
         final_full_scores = combined_score_expr.view(full_shape)
     return final_full_scores
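A small sketch of the margin step in isolation: the batch and modifier dims are flattened so one arange can address every gold (head, label) cell (plain PyTorch, hypothetical shapes):

 import torch

 bs, slen, nlab, margin = 2, 4, 3, 1.0
 scores = torch.zeros(bs, slen, slen, nlab)  # [bs, m, h, Lab]
 gold_heads = torch.randint(0, slen, (bs, slen))  # [bs, m]
 gold_labels = torch.randint(0, nlab, (bs, slen))  # [bs, m]

 combined = scores.view(bs * slen, slen, nlab)  # [bs*m, h, Lab], shares storage
 flat_idx = torch.arange(bs * slen)
 combined[flat_idx, gold_heads.view(-1)] -= margin  # all labels of the gold head
 combined[flat_idx, gold_heads.view(-1), gold_labels.view(-1)] -= margin  # the gold label cell
 scores = combined.view(bs, slen, slen, nlab)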
Example #14
 def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
     # todo(+N): currently margin is not used
     conf = self.conf
     bsize = len(insts)
     arange_t = BK.arange_idx(bsize)
     assert conf.train_force, "currently only have forced training"
     # get the gold ones
     gold_widxes, gold_lidxes, gold_vmasks, ret_items, _ = self.batch_inputs_g1(insts)  # [*, ?]
     # for all the steps
     num_step = BK.get_shape(gold_widxes, -1)
     # recurrent states
     hard_coverage = BK.zeros(BK.get_shape(input_mask))  # [*, slen]
     prev_state = self.rnn_unit.zero_init_hidden(bsize)  # tuple([*, D], )
     all_tok_logprobs, all_lab_logprobs = [], []
     for cstep in range(num_step):
         slice_widx, slice_lidx = gold_widxes[:,cstep], gold_lidxes[:,cstep]
         _, sel_tok_logprobs, _, sel_lab_logprobs, _, next_state = \
             self._step(input_expr, input_mask, hard_coverage, prev_state, slice_widx, slice_lidx, None)
         all_tok_logprobs.append(sel_tok_logprobs)  # add one of [*, 1]
         all_lab_logprobs.append(sel_lab_logprobs)
         hard_coverage = BK.copy(hard_coverage)  # todo(note): cannot modify inplace!
         hard_coverage[arange_t, slice_widx] += 1.
         prev_state = [z.squeeze(-2) for z in next_state]
     # concat all the loss and mask
     # todo(note): no need to use gold_valid since everything is already told by vmasks
     cat_tok_logprobs = BK.concat(all_tok_logprobs, -1) * gold_vmasks  # [*, steps]
     cat_lab_logprobs = BK.concat(all_lab_logprobs, -1) * gold_vmasks
     loss_sum = - (cat_tok_logprobs.sum() * conf.lambda_att + cat_lab_logprobs.sum() * conf.lambda_lab)
     # todo(+N): here we are dividing lab_logprobs with the all-count, do we need to separate?
     loss_count = gold_vmasks.sum()
     ret_losses = [[loss_sum, loss_count]]
     # =====
     # make eos invalid for return
     ret_valid_mask = gold_vmasks * (gold_widxes>0).float()
     # embeddings
     sel_lab_embeds = self._hl_lookup(gold_lidxes)
     return ret_losses, ret_items, gold_widxes, ret_valid_mask, gold_lidxes, sel_lab_embeds
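The hard-coverage bump is a one-cell-per-row scatter driven by a batch arange; a tiny sketch (plain PyTorch, hypothetical sizes):

 import torch

 bsize, slen = 3, 6
 hard_coverage = torch.zeros(bsize, slen)
 slice_widx = torch.tensor([2, 0, 5])  # selected token per batch row
 hard_coverage = hard_coverage.clone()  # mirror the copy-before-modify above
 hard_coverage[torch.arange(bsize), slice_widx] += 1.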
Example #15
 def predict(self, ms_items: List, bert_expr, basic_expr):
     conf = self.conf
     bsize = len(ms_items)
     # todo(note): use the pred_events which are shallow copied from inputs
     offsets_t, masks_t, _, items_arr, _ = PrepHelper.prep_targets(
         ms_items, lambda x: x.pred_events, True, False, 0., 0.,
         True)  # [bs, ?]
     # -----
     if BK.get_shape(offsets_t, -1) == 0:
         return  # no input
     # -----
     # similar ones
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     # predict: only top-1!
     if conf.pred_realis:
         self._pred_and_put_res(self.realis_predictor, hiddens, items_arr,
                                self._put_realis)
     if conf.pred_type:
         self._pred_and_put_res(self.type_predictor, hiddens, items_arr,
                                self._put_type)
Example #16
 def predict(self, ms_items: List, bert_expr, basic_expr):
     conf = self.conf
     bsize = len(ms_items)
     # build targets (include all sents)
     offsets_t, masks_t, _, _, _ = PrepHelper.prep_targets(
         ms_items, lambda x: [], True, True, 1., 1., False)
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     logits = self.predictor(hiddens)  # [bsize, ?, Out]
     # -----
     log_probs = BK.log_softmax(logits, -1)
     log_probs[:, :, 0] -= conf.nil_penalty  # encourage more predictions
     topk_log_probs, topk_log_labels = log_probs.max(
         dim=-1)  # [bsize, ?]
     # decoding
     head_offsets_arr = BK.get_value(offsets_t)  # [bs, ?]
     masks_arr = BK.get_value(masks_t)
     topk_log_probs_arr, topk_log_labels_arr = BK.get_value(
         topk_log_probs), BK.get_value(topk_log_labels)  # [bsize, ?, k]
     for one_ms_item, one_offsets_arr, one_masks_arr, one_logprobs_arr, one_labels_arr \
             in zip(ms_items, head_offsets_arr, masks_arr, topk_log_probs_arr, topk_log_labels_arr):
         # build tidx2sidx
         one_sents = one_ms_item.sents
         one_offsets = one_ms_item.offsets
         tidx2sidx = []
         for idx in range(1, len(one_offsets)):
             tidx2sidx.extend([idx - 1] *
                              (one_offsets[idx] - one_offsets[idx - 1]))
         # get all candidates
         all_candidates = [[] for _ in one_sents]
         for cur_offset, cur_valid, cur_logprob, cur_label in zip(
                 one_offsets_arr, one_masks_arr, one_logprobs_arr,
                 one_labels_arr):
             if not cur_valid or cur_label <= 0:
                 continue
             # which sent
             cur_offset = int(cur_offset)
             cur_sidx = tidx2sidx[cur_offset]
             cur_sent = one_sents[cur_sidx]
             minus_offset = one_ms_item.offsets[
                 cur_sidx] - 1  # again consider the ROOT
             cur_mention = Mention(
                 HardSpan(cur_sent.sid, cur_offset - minus_offset, None,
                          None))
             all_candidates[cur_sidx].append(
                 (cur_sent, cur_mention, cur_label, cur_logprob))
         # keep certain ratio for each sent separately?
         final_candidates = []
         if conf.pred_sent_ratio_sep:
             for one_sent, one_sent_candidates in zip(
                     one_sents, all_candidates):
                 cur_keep_num = max(
                     int(conf.pred_sent_ratio * (one_sent.length - 1)), 1)
                 one_sent_candidates.sort(key=lambda x: x[-1], reverse=True)
                 final_candidates.extend(one_sent_candidates[:cur_keep_num])
         else:
             all_size = 0
             for one_sent, one_sent_candidates in zip(
                     one_sents, all_candidates):
                 all_size += one_sent.length - 1
                 final_candidates.extend(one_sent_candidates)
             final_candidates.sort(key=lambda x: x[-1], reverse=True)
             final_keep_num = max(int(conf.pred_sent_ratio * all_size),
                                  len(one_sents))
             final_candidates = final_candidates[:final_keep_num]
         # add them all
         for cur_sent, cur_mention, cur_label, cur_logprob in final_candidates:
             cur_logprob = float(cur_logprob)
             doc_id = cur_sent.doc.doc_id
             self.id_counter[doc_id] += 1
             new_id = f"ef-{doc_id}-{self.id_counter[doc_id]}"
             hlidx = self.valid_hlidx
             new_ef = EntityFiller(new_id,
                                   cur_mention,
                                   str(hlidx),
                                   None,
                                   True,
                                   type_idx=hlidx,
                                   score=cur_logprob)
             cur_sent.pred_entity_fillers.append(new_ef)
Example #17
 def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
     conf = self.conf
     CR, PR = conf.cand_range, conf.pred_range
     # -----
     mask_single = BK.copy(mask_t)
     # no predictions for ARTI_ROOT
     if self.add_root_token:
         mask_single[:, 0] = 0.  # [bs, slen]
     # casting predicting range
     cur_slen = BK.get_shape(mask_single, -1)
     arange_t = BK.arange_idx(cur_slen)  # [slen]
     # [1, len] - [len, 1] = [len, len]
     reldist_t = (arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1)
                  )  # [slen, slen]
     mask_pair = ((reldist_t.abs() <= CR) &
                  (reldist_t != 0)).float()  # within CR-range; [slen, slen]
     mask_pair = mask_pair * mask_single.unsqueeze(
         -1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
     if disturb_keep_arr is not None:
         mask_pair *= BK.input_real(1. - disturb_keep_arr).unsqueeze(
             -1)  # no predictions for the kept ones!
     # get all pair scores
     score_t = self.ps_node.paired_score(
         repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
     # -----
     # loss: normalize on which dim?
     # get the answers first
     if conf.pred_abs:
         answer_t = reldist_t.abs()  # [1,2,3,...,PR]
         answer_t.clamp_(
             min=0, max=PR -
             1)  # [slen, slen], clip in range, distinguish using masks
     else:
         answer_t = BK.where(
             (reldist_t >= 0), reldist_t - 1,
             reldist_t + 2 * PR)  # [1,2,3,...PR,-PR,-PR+1,...,-1]
         answer_t.clamp_(
             min=0, max=2 * PR -
             1)  # [slen, slen], clip in range, distinguish using masks
     # expand answer into idxes
     answer_hit_t = BK.zeros(BK.get_shape(answer_t) +
                             [2 * PR])  # [len_q, len_k, 2*R]
     answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
     answer_valid_t = ((reldist_t.abs() <= PR) &
                       (reldist_t != 0)).float().unsqueeze(
                           -1)  # [bs, len_q, len_k, 1]
     answer_hit_t = answer_hit_t * mask_pair.unsqueeze(
         -1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
     # get losses sum(log(answer*prob))
     # -- dim=-1 is standard 2*PR classification, dim=-2 usually has 2*PR candidates, but can be less at edges
     all_losses = []
     for one_dim, one_lambda in zip([-1, -2],
                                    [conf.lambda_n1, conf.lambda_n2]):
         if one_lambda > 0.:
             # since currently there can be only one or zero correct answer
             logprob_t = BK.log_softmax(score_t,
                                        one_dim)  # [bs, len_q, len_k, 2*R]
             sumlogprob_t = (logprob_t * answer_hit_t).sum(
                 one_dim)  # [bs, len_q, len_k||2*R]
             cur_dim_mask_t = (answer_hit_t.sum(one_dim) >
                               0.).float()  # [bs, len_q, len_k||2*R]
             # loss
             cur_dim_loss = -(sumlogprob_t * cur_dim_mask_t).sum()
             cur_dim_count = cur_dim_mask_t.sum()
             # argmax and corr (any correct counts)
             _, cur_argmax_idxes = score_t.max(one_dim)
             cur_corrs = answer_hit_t.gather(
                 one_dim, cur_argmax_idxes.unsqueeze(
                     one_dim))  # [bs, len_q, len_k|1, 2*R|1]
             cur_dim_corr_count = cur_corrs.sum()
             # compile loss
             one_loss = LossHelper.compile_leaf_info(
                 f"d{one_dim}",
                 cur_dim_loss,
                 cur_dim_count,
                 loss_lambda=one_lambda,
                 corr=cur_dim_corr_count)
             all_losses.append(one_loss)
     return self._compile_component_loss("orp", all_losses)
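A compact sketch of the signed relative-distance bucketing used for the answers above: distances 1..PR map to classes 0..PR-1 and -PR..-1 to classes PR..2*PR-1, with out-of-range values clamped (plain PyTorch, hypothetical sizes):

 import torch

 PR, slen = 3, 5
 ar = torch.arange(slen)
 reldist_t = ar.unsqueeze(-2) - ar.unsqueeze(-1)  # [slen, slen]
 answer_t = torch.where(reldist_t >= 0, reldist_t - 1, reldist_t + 2 * PR)
 answer_t.clamp_(min=0, max=2 * PR - 1)  # dist=0 also lands on class 0; it is masked out separately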
Example #18
 def run(self, insts: List[DocInstance], training: bool):
     conf = self.conf
     BERT_MAX_LEN = 510  # save 2 for CLS and SEP
     # =====
     # encoder 1: the basic encoder
     # todo(note): only DocInstance input for this mode, otherwise it will break
     if conf.m2e_use_basic:
         reidx_pad_len = conf.ms_extend_budget
         # enc the basic part + also get some indexes
         sentid2offset = {}  # id(sent)->overall_seq_offset
         seq_offset = 0  # if look at the docs in one seq
         all_sents = []  # (inst, d_idx, s_idx)
         for d_idx, one_doc in enumerate(insts):
             assert isinstance(one_doc, DocInstance)
             for s_idx, one_sent in enumerate(one_doc.sents):
                 # todo(note): here we encode all the sentences
                 all_sents.append((one_sent, d_idx, s_idx))
                 sentid2offset[id(one_sent)] = seq_offset
                 seq_offset += one_sent.length - 1  # exclude extra ROOT node
         sent_reprs = self.run_sents(all_sents, insts, training)
         # flatten and concatenate and re-index
         reidxes_arr = np.zeros(
             seq_offset + reidx_pad_len, dtype=np.long
         )  # todo(note): extra padding to avoid out of boundary
         all_flattened_reprs = []
         all_flatten_offset = 0  # the local offset for batched basic encoding
         for one_pack in sent_reprs:
             one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             one_repr_t = one_repr_evt
             _, one_slen, one_ldim = BK.get_shape(one_repr_t)
             all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
             # fill in the indexes
             for one_sent in one_sents:
                 cur_start_offset = sentid2offset[id(one_sent)]
                 cur_real_slen = one_sent.length - 1
                 # again, +1 to get rid of extra ROOT
                 reidxes_arr[cur_start_offset:cur_start_offset+cur_real_slen] = \
                     np.arange(cur_real_slen, dtype=np.long) + (all_flatten_offset+1)
                 all_flatten_offset += one_slen  # here add the slen in batched version
         # re-idxing
         seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
         seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr,
                                   0)  # [all_seq_len, D]
     else:
         sentid2offset = defaultdict(int)
         seq_sent_repr = None
     # =====
     # repack and prepare for multiple sent enc
     # todo(note): here, the criterion is based on bert's tokenizer
     all_ms_info = []
     if isinstance(insts[0], DocInstance):
         for d_idx, one_doc in enumerate(insts):
             for s_idx, x in enumerate(one_doc.sents):
                 # the basic criterion is the same as the basic one
                 include_flag = False
                 if training:
                     if x.length<self.train_skip_length and x.length>=self.train_min_length \
                             and (len(x.events)>0 or next(self.random_sample_stream)>self.train_skip_noevt_rate):
                         include_flag = True
                 else:
                     if x.length >= self.test_min_length:
                         include_flag = True
                 if include_flag:
                     all_ms_info.append(
                         x.preps["ms"])  # use the pre-calculated one
     else:
         # multisent based
         all_ms_info = insts.copy()  # shallow copy
     # =====
     # encoder 2: the bert one (multi-sent encoding)
     ms_size_f = lambda x: x.subword_size
     all_ms_info.sort(key=ms_size_f)
     all_ms_buckets = self._bucket_sents_by_length(
         all_ms_info,
         conf.benc_bucket_range,
         ms_size_f,
         max_bsize=conf.benc_bucket_msize)
     berter = self.berter
     rets = []
     bert_use_center_typeids = conf.bert_use_center_typeids
     bert_use_special_typeids = conf.bert_use_special_typeids
     bert_other_inputs = conf.bert_other_inputs
     for one_bucket in all_ms_buckets:
         # prepare
         batched_ids = []
         batched_starts = []
         batched_seq_offset = []
         batched_typeids = []
         batched_other_inputs_list: List = [
             [] for _ in bert_other_inputs
         ]  # List(comp) of List(batch) of List(idx)
         for one_item in one_bucket:
             one_sents = one_item.sents
             one_center_sid = one_item.center_idx
             one_ids, one_starts, one_typeids = [], [], []
             one_other_inputs_list = [[] for _ in bert_other_inputs
                                      ]  # List(comp) of List(idx)
             for one_sid, one_sent in enumerate(one_sents):  # for bert
                 one_bidxes = one_sent.preps["bidx"]
                 one_ids.extend(one_bidxes.subword_ids)
                 one_starts.extend(one_bidxes.subword_is_start)
                 # prepare other inputs
                 for this_field_name, this_tofill_list in zip(
                         bert_other_inputs, one_other_inputs_list):
                     this_tofill_list.extend(
                         one_sent.preps["sub_" + this_field_name])
                 # todo(note): special procedure
                 if bert_use_center_typeids:
                     if one_sid != one_center_sid:
                         one_typeids.extend([0] *
                                            len(one_bidxes.subword_ids))
                     else:
                         this_typeids = [1] * len(one_bidxes.subword_ids)
                         if bert_use_special_typeids:
                             # todo(note): this is the special mode that we are given the events!!
                             for this_event in one_sents[
                                     one_center_sid].events:
                                 _, this_wid, this_wlen = this_event.mention.hard_span.position(
                                     headed=False)
                                 for a, b in one_item.center_word2sub[
                                         this_wid - 1:this_wid - 1 +
                                         this_wlen]:
                                     this_typeids[a:b] = [0] * (b - a)
                         one_typeids.extend(this_typeids)
             batched_ids.append(one_ids)
             batched_starts.append(one_starts)
             batched_typeids.append(one_typeids)
             for comp_one_oi, comp_batched_oi in zip(
                     one_other_inputs_list, batched_other_inputs_list):
                 comp_batched_oi.append(comp_one_oi)
             # for basic part
             batched_seq_offset.append(sentid2offset[id(one_sents[0])])
         # bert forward: [bs, slen, fold, D]
         if not bert_use_center_typeids:
             batched_typeids = None
         bert_expr0, mask_expr = berter.forward_batch(
             batched_ids,
             batched_starts,
             batched_typeids,
             training=training,
             other_inputs=batched_other_inputs_list)
         if self.m3_enc_is_empty:
             bert_expr = bert_expr0
         else:
             mask_arr = BK.get_value(mask_expr)  # [bs, slen]
             m3e_exprs = [
                 cur_enc(bert_expr0[:, :, cur_i], mask_arr)
                 for cur_i, cur_enc in enumerate(self.m3_encs)
             ]
             bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
         # collect basic ones: [bs, slen, D'] or None
         if seq_sent_repr is not None:
             arange_idxes_t = BK.arange_idx(BK.get_shape(
                 mask_expr, -1)).unsqueeze(0)  # [1, slen]
             offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(
                 -1) + arange_idxes_t  # [bs, slen]
             basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
         elif conf.m2e_use_basic_dep:
             # collect each token's head-bert and ud-label, then forward with adp
             fake_sents = [one_item.fake_sent for one_item in one_bucket]
             # head idx and labels, no artificial ROOT
             padded_head_arr, _ = self.dep_padder.pad(
                 [s.ud_heads.vals[1:] for s in fake_sents])
             padded_label_arr, _ = self.dep_padder.pad(
                 [s.ud_labels.idxes[1:] for s in fake_sents])
             # get tensor
             padded_head_t = (BK.input_idx(padded_head_arr) - 1
                              )  # here, the idx exclude root
             padded_head_t.clamp_(min=0)  # [bs, slen]
             padded_label_t = BK.input_idx(padded_label_arr)
             # get inputs
             input_head_bert_t = bert_expr[
                 BK.arange_idx(len(fake_sents)).unsqueeze(-1),
                 padded_head_t]  # [bs, slen, fold, D]
             input_label_emb_t = self.dep_label_emb(
                 padded_label_t)  # [bs, slen, D']
             basic_expr = self.dep_layer(
                 input_head_bert_t, None,
                 [input_label_emb_t])  # [bs, slen, ?]
         elif conf.m2e_use_basic_plus:
             sent_reprs = self.run_sents([(one_item.fake_sent, None, None)
                                          for one_item in one_bucket],
                                         insts,
                                         training,
                                         use_one_bucket=True)
             assert len(
                 sent_reprs
             ) == 1, "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
             _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
             assert BK.get_shape(basic_expr)[:2] == BK.get_shape(
                 bert_expr)[:2]
         else:
             basic_expr = None
         # pack: (List[ms_item], bert_expr, basic_expr)
         rets.append((one_bucket, bert_expr, basic_expr))
     return rets
Example #19
 def update_call(self,
                 cache: VRecCache,
                 src_mask=None,
                 qk_mask=None,
                 attn_range=None,
                 rel_dist=None,
                 temperature=1.,
                 forced_attn=None):
     conf = self.conf
     # -----
     # first call matt to get v
     matt_input_qk = cache.orig_t * conf.feat_qk_lambda_orig + cache.rec_t * (
         1. - conf.feat_qk_lambda_orig)
     if self.att_pre_norm:
         matt_input_qk = self.att_pre_norm(matt_input_qk)
     matt_input_v = cache.orig_t * conf.feat_v_lambda_orig + cache.rec_t * (
         1. - conf.feat_v_lambda_orig)
     # todo(note): currently no pre-norm for matt_input_v
     # put attn_range as mask_qk
     if attn_range is not None and attn_range >= 0:  # <0 means not effective
         cur_slen = BK.get_shape(matt_input_qk, -2)
         tmp_arange_t = BK.arange_idx(cur_slen)  # [slen]
         # less or equal!!
         mask_qk = (
             (tmp_arange_t.unsqueeze(-1) - tmp_arange_t.unsqueeze(0)).abs()
             <= attn_range).float()
         if qk_mask is not None:  # further with input masks
             mask_qk *= qk_mask
     else:
         mask_qk = qk_mask
     scores, attn_info, result_value = self.feat_node(
         matt_input_qk,
         matt_input_qk,
         matt_input_v,
         cache.accu_attn,
         mask_k=src_mask,
         mask_qk=mask_qk,
         rel_dist=rel_dist,
         temperature=temperature,
         forced_attn=forced_attn)  # ..., [*, len_q, dv]
     # -----
     # then combine q(hidden) and v(input)
     comb_input_q = cache.orig_t * conf.comb_q_lambda_orig + cache.rec_t * (
         1. - conf.comb_q_lambda_orig)
     comb_result, comb_c = self.comb_f(
         comb_input_q, result_value, cache.rec_lstm_c_t)  # [*, len_q, dim]
     if self.att_post_norm:
         comb_result = self.att_post_norm(comb_result)
     # -----
     # ff
     if self.has_ff:
         if self.ff_pre_norm:
             ff_input = self.ff_pre_norm(comb_result)
         else:
             ff_input = comb_result
         ff_output = comb_result + self.dropout2(
             self.linear2(self.dropout1(self.linear1(ff_input))))
         if self.ff_post_norm:
             ff_output = self.ff_post_norm(ff_output)
     else:  # otherwise no ff
         ff_output = comb_result
     # -----
     # update cache and return output
     # cache.orig_t = cache.orig_t  # this does not change
     cache.rec_t = ff_output
     cache.accu_attn = cache.accu_attn + attn_info[0]  # accumulating attn
     cache.rec_lstm_c_t = comb_c  # optional C for lstm
     cache.list_hidden.append(ff_output)  # all hidden layers
     cache.list_score.append(scores)  # all un-normed scores
     cache.list_attn.append(attn_info[0])  # all normed scores
     cache.list_accu_attn.append(cache.accu_attn)  # all accumulated attns
     cache.list_attn_info.append(attn_info)  # all attn infos
     return ff_output
Example #20
 def predict(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     bsize, slen = BK.get_shape(input_mask)
     bsize_arange_t_1d = BK.arange_idx(bsize)  # [*]
     bsize_arange_t_2d = bsize_arange_t_1d.unsqueeze(-1)  # [*, 1]
     beam_size = conf.beam_size
     # prepare things with an extra beam dimension
     beam_input_expr, beam_input_mask = input_expr.unsqueeze(-3).expand(-1, beam_size, -1, -1).contiguous(), \
                                        input_mask.unsqueeze(-2).expand(-1, beam_size, -1).contiguous()  # [*, beam, slen, D?]
     # -----
     # recurrent states
     beam_hard_coverage = BK.zeros([bsize, beam_size, slen])  # [*, beam, slen]
     # tuple([*, beam, D], )
     beam_prev_state = [z.unsqueeze(-2).expand(-1, beam_size, -1) for z in self.rnn_unit.zero_init_hidden(bsize)]
     # frozen after reach eos
     beam_noneos = 1.-BK.zeros([bsize, beam_size])  # [*, beam]
     beam_logprobs = BK.zeros([bsize, beam_size])  # [*, beam], sum of logprobs
     beam_logprobs_paths = BK.zeros([bsize, beam_size, 0])  # [*, beam, step]
     beam_tok_paths = BK.zeros([bsize, beam_size, 0]).long()
     beam_lab_paths = BK.zeros([bsize, beam_size, 0]).long()
     # -----
     for cstep in range(conf.max_step):
         # get things of [*, beam, beam]
         sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state = \
             self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, None, None, beam_size)
         sel_logprobs = sel_tok_logprobs + sel_lab_logprobs  # [*, beam, beam]
         if cstep == 0:
             # special for the first step, only select for the first element
             cur_selections = BK.arange_idx(beam_size).unsqueeze(0).expand(bsize, beam_size)  # [*, beam]
         else:
             # then select the topk in beam*beam (be careful about the frozen ones!!)
             beam_noneos_3d = beam_noneos.unsqueeze(-1)
             # eos can only followed by eos
             sel_tok_idxes *= beam_noneos_3d.long()
             sel_lab_idxes *= beam_noneos_3d.long()
             # numeric tricks to keep the frozen ones ([0] with 0. score, [1:] with -inf scores)
             sel_logprobs *= beam_noneos_3d
             tmp_exclude_mask = 1. - beam_noneos_3d.expand_as(sel_logprobs)
             tmp_exclude_mask[:, :, 0] = 0.
             sel_logprobs += tmp_exclude_mask * Constants.REAL_PRAC_MIN
             # select for topk
             topk_logprobs = (beam_noneos * beam_logprobs).unsqueeze(-1) + sel_logprobs
             _, cur_selections = topk_logprobs.view([bsize, -1]).topk(beam_size, dim=-1, sorted=True)  # [*, beam]
         # read and write the selections
         # gathering previous ones
         cur_sel_previ = cur_selections // beam_size  # [*, beam]
         prev_hard_coverage = beam_hard_coverage[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, slen]
         prev_noneos = beam_noneos[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
         prev_logprobs = beam_logprobs[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
         prev_logprobs_paths = beam_logprobs_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
         prev_tok_paths = beam_tok_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
         prev_lab_paths = beam_lab_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
         # prepare new ones
         cur_sel_newi = cur_selections % beam_size
         new_tok_idxes = sel_tok_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
         new_lab_idxes = sel_lab_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
         new_logprobs = sel_logprobs[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
         new_prev_state = [z[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi] for z in next_state]  # [*, beam, ~]
         # update
         prev_hard_coverage[bsize_arange_t_2d, BK.arange_idx(beam_size).unsqueeze(0), new_tok_idxes] += 1.
         beam_hard_coverage = prev_hard_coverage
         beam_prev_state = new_prev_state
         beam_noneos = prev_noneos * (new_tok_idxes!=0).float()
         beam_logprobs = prev_logprobs + new_logprobs
         beam_logprobs_paths = BK.concat([prev_logprobs_paths, new_logprobs.unsqueeze(-1)], -1)
         beam_tok_paths = BK.concat([prev_tok_paths, new_tok_idxes.unsqueeze(-1)], -1)
         beam_lab_paths = BK.concat([prev_lab_paths, new_lab_idxes.unsqueeze(-1)], -1)
     # finally force an extra eos step to get ending tok-logprob (no need to update other things)
     final_eos_idxes = BK.zeros([bsize, beam_size]).long()
     _, eos_logprobs, _, _, _, _ = self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, final_eos_idxes, final_eos_idxes, None)
     beam_logprobs += eos_logprobs.squeeze(-1) * beam_noneos  # [*, beam]
     # select and return the best one
     beam_tok_valids = (beam_tok_paths > 0).float()  # [*, beam, steps]
     final_scores = beam_logprobs / ((beam_tok_valids.sum(-1) + 1.) ** conf.len_alpha)  # [*, beam]
     _, best_beam_idx = final_scores.max(-1)  # [*]
     # -----
     # prepare returns; cut by max length: [*, all_step] -> [*, max_step]
     ret0_valid_mask = beam_tok_valids[bsize_arange_t_1d, best_beam_idx]
     cur_max_step = ret0_valid_mask.long().sum(-1).max().item()
     ret_valid_mask = ret0_valid_mask[:, :cur_max_step]
     ret_logprobs = beam_logprobs_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
     ret_tok_idxes = beam_tok_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
     ret_lab_idxes = beam_lab_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
     # embeddings
     ret_lab_embeds = self._hl_lookup(ret_lab_idxes)
     return ret_logprobs, ret_tok_idxes, ret_valid_mask, ret_lab_idxes, ret_lab_embeds
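The beam bookkeeping hinges on decomposing each flat top-k index back into (previous beam entry, expansion); a minimal sketch of that step (plain PyTorch, hypothetical sizes):

 import torch

 bsize, beam_size = 2, 4
 topk_logprobs = torch.randn(bsize, beam_size, beam_size)  # [*, beam, beam]
 _, cur_selections = topk_logprobs.view(bsize, -1).topk(beam_size, dim=-1)  # [*, beam]
 cur_sel_previ = cur_selections // beam_size  # which previous beam entry to extend
 cur_sel_newi = cur_selections % beam_size  # which of its expansions was picked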
Example #21
 def update_bsize(self, new_bsize):
     if new_bsize != self.cur_bsize:
         self.cur_bsize = new_bsize
         self.bsize_range_t = BK.arange_idx(new_bsize)
Example #22
 def get_rel_dist(self, len_q: int, len_k: int):
     dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
     dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
     distance = dist_x - dist_y  # [len_q, len_k]
     return distance
Example #23
 def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
     gold_all_idxes = self._get_all_idxes(gold_idxes)
     # scoring
     raw_scores = self._raw_scores(input_expr)
     raw_scores_aug = []
     margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
     #
     gold_shape = BK.get_shape(gold_idxes)  # [*]
     gold_bsize_prod = np.prod(gold_shape)
     # gold_arange_idxes = BK.arange_idx(gold_bsize_prod)
     # margin
     for i in range(self.eff_max_layer):
         cur_gold_inputs = gold_all_idxes[i]
         # add margin
         cur_scores = raw_scores[i]  # [*, ?]
         cur_margin = margin * self.margin_lambdas[i]
         if cur_margin > 0.:
             cur_num_target = self.prediction_sizes[i]
             cur_isnil = self.layered_isnil[i].byte()  # [NLab]
             cost_matrix = BK.constants([cur_num_target, cur_num_target],
                                        margin_T)  # [gold, pred]
             cost_matrix[cur_isnil, :] = margin_P
             cost_matrix[:, cur_isnil] = margin_R
             diag_idxes = BK.arange_idx(cur_num_target)
             cost_matrix[diag_idxes, diag_idxes] = 0.
             margin_mat = cost_matrix[cur_gold_inputs]
             cur_aug_scores = cur_scores + margin_mat  # [*, ?]
         else:
             cur_aug_scores = cur_scores
         raw_scores_aug.append(cur_aug_scores)
     # cascade scores
     final_scores = self._cascade_scores(raw_scores_aug)
     # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
     loss_weights = ((gold_idxes == 0).float() *
                     (self.loss_fullnil_weight - 1.) +
                     1.) if self.loss_fullnil_weight < 1. else 1.
     # calculate loss
     loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
     loss_prob_reweight = self.conf.loss_prob_reweight
     final_losses = []
     no_loss_max_gold = self.conf.no_loss_max_gold
     if loss_mask is None:
         loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
     for i in range(self.eff_max_layer):
         cur_final_scores, cur_gold_inputs = final_scores[
             i], gold_all_idxes[i]  # [*, ?], [*]
         # collect the loss
         if self.is_hinge_loss:
             cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
             cur_gold_scores = BK.gather(cur_final_scores,
                                         cur_gold_inputs.unsqueeze(-1),
                                         -1).squeeze(-1)
             cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
             if no_loss_max_gold:  # this should be implicit
                 cur_loss = cur_loss * (cur_loss > 0.).float()
         elif self.is_prob_loss:
             # cur_loss = BK.loss_nll(cur_final_scores, cur_gold_inputs)  # [*]
             cur_loss = self._my_loss_prob(cur_final_scores,
                                           cur_gold_inputs,
                                           loss_prob_entropy_lambda,
                                           loss_mask,
                                           loss_prob_reweight)  # [*]
             if no_loss_max_gold:
                 cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                 cur_gold_scores = BK.gather(cur_final_scores,
                                             cur_gold_inputs.unsqueeze(-1),
                                             -1).squeeze(-1)
                 cur_loss = cur_loss * (cur_gold_scores >
                                        cur_pred_scores).float()
         else:
             raise NotImplementedError(
                 f"UNK loss {self.conf.loss_function}")
         # here first summing up, divided at the outside
         one_loss_sum = (
             cur_loss *
             (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
         final_losses.append(one_loss_sum)
     # final sum
     final_loss_sum = BK.stack(final_losses).sum()
     _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
     return [[final_loss_sum,
              loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds
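A hedged sketch of the cost-augmented margin matrix above: rows are gold labels, columns are predictions, NIL rows/columns get their own margins, and the diagonal (correct prediction) costs nothing (plain PyTorch, hypothetical sizes):

 import torch

 num_target, margin_P, margin_R, margin_T = 4, 0.5, 1.0, 0.25
 cur_isnil = torch.tensor([True, False, False, False])  # label 0 is NIL

 cost_matrix = torch.full((num_target, num_target), margin_T)  # [gold, pred]
 cost_matrix[cur_isnil, :] = margin_P  # gold is NIL
 cost_matrix[:, cur_isnil] = margin_R  # prediction is NIL
 diag_idxes = torch.arange(num_target)
 cost_matrix[diag_idxes, diag_idxes] = 0.

 cur_gold_inputs = torch.tensor([0, 2, 2])  # gold labels, [*]
 margin_mat = cost_matrix[cur_gold_inputs]  # [*, num_target], added onto the scores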
Example #24
 def _loss(self,
           annotated_insts: List[ParseInstance],
           full_score_expr,
           mask_expr,
           valid_expr=None):
     bsize, max_len = BK.get_shape(mask_expr)
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     # todo(note): here use the original idx of label, no shift!
     gold_labels_arr, _ = self.predict_padder.pad(
         [z.labels.idxes for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     #
     idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
     idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
     # scores for decoding or marginal
     margin = self.margin.value
     decoding_scores = full_score_expr.clone().detach()
     decoding_scores = self.scorer_helper.postprocess_scores(
         decoding_scores, mask_expr, margin, gold_heads_expr,
         gold_labels_expr)
     if self.loss_hinge:
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         pred_heads_expr, pred_labels_expr, _ = nmst_unproj(decoding_scores,
                                                            mask_expr,
                                                            mst_lengths_arr,
                                                            labeled=True,
                                                            ret_arr=False)
         # ===== add margin*cost, [bs, len]
         gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr,
                                             gold_heads_expr,
                                             gold_labels_expr]
         pred_final_scores = full_score_expr[
             idxes_bs_expr, idxes_m_expr, pred_heads_expr,
             pred_labels_expr] + margin * (
                 gold_heads_expr != pred_heads_expr).float() + margin * (
                     gold_labels_expr !=
                     pred_labels_expr).float()  # plus margin
         hinge_losses = pred_final_scores - gold_final_scores
         valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) >
                         0.).float().unsqueeze(-1)  # [*, 1]
         final_losses = hinge_losses * valid_losses
     else:
         lab_marginals = nmarginal_unproj(decoding_scores,
                                          mask_expr,
                                          None,
                                          labeled=True)
         lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr,
                       gold_labels_expr] -= 1.
         grads_masked = lab_marginals * mask_expr.unsqueeze(-1).unsqueeze(
             -1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
         final_losses = (full_score_expr * grads_masked).sum(-1).sum(
             -1)  # [bs, m]
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     # exclude non-valid ones: there can be pruning error
     if valid_expr is not None:
         final_valids = valid_expr[idxes_bs_expr, idxes_m_expr,
                                   gold_heads_expr]  # [bs, m] of (0. or 1.)
         final_losses = final_losses * final_valids
         tok_valid = float(BK.get_value(final_valids[:, 1:].sum()))
         assert tok_valid <= num_valid_tok
         tok_prune_err = num_valid_tok - tok_valid
     else:
         tok_prune_err = 0
     # collect loss with mask, also excluding the first symbol of ROOT
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "tok_prune_err": tok_prune_err,
         "loss_sum": final_loss_sum_val
     }
     return final_loss, info