Example #1
File: model.py Project: ValentinaPy/zmsp
 def _fb_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
              evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin):
     # get the gold idxes
     arg_linker = self.arg_linker
     bsize, len_ef = ef_items.shape
     bsize2, len_evt = evt_items.shape
     assert bsize == bsize2
     gold_idxes = np.zeros([bsize, len_ef, len_evt], dtype=np.int64)
     for one_gold_idxes, one_ef_items, one_evt_items in zip(gold_idxes, ef_items, evt_items):
         # todo(note): check each pair
         for ef_idx, one_ef in enumerate(one_ef_items):
             if one_ef is None:
                 continue
             role_map = {id(z.evt): z.role_idx for z in one_ef.links}  # todo(note): since we get the original linked ones
             for evt_idx, one_evt in enumerate(one_evt_items):
                 pairwise_role_hlidx = role_map.get(id(one_evt))
                 if pairwise_role_hlidx is not None:
                     pairwise_role_idx = arg_linker.hlidx2idx(pairwise_role_hlidx)
                     assert pairwise_role_idx > 0
                     one_gold_idxes[ef_idx, evt_idx] = pairwise_role_idx
     # get loss
     repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
     repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
     if np.prod(gold_idxes.shape) == 0:
         # no instances!
         return [[BK.zeros([]), BK.zeros([])]]
     else:
         gold_idxes_t = BK.input_idx(gold_idxes)
         return arg_linker.loss(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes, ef_valid_mask, evt_valid_mask,
                                gold_idxes_t, margin)
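
The function above converts the sparse gold links stored on each entity item into a dense [bsize, len_ef, len_evt] role matrix (0 = no link) before handing it to the arg linker's loss. A minimal NumPy sketch of that fill pattern, using a hypothetical `build_gold_roles` helper that is not part of zmsp:

    import numpy as np

    def build_gold_roles(links, len_ef, len_evt):
        # links: iterable of (ef_idx, evt_idx, role_idx) with role_idx > 0
        gold = np.zeros([len_ef, len_evt], dtype=np.int64)  # 0 means "no role"
        for ef_idx, evt_idx, role_idx in links:
            gold[ef_idx, evt_idx] = role_idx
        return gold

    print(build_gold_roles([(0, 1, 3), (2, 0, 1)], len_ef=3, len_evt=2))
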
Example #2
File: model.py Project: ValentinaPy/zmsp
 def _inference_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                     evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt):
     arg_linker = self.arg_linker
     repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
     repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
     role_logprobs, role_predictions = arg_linker.predict(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                                                          ef_valid_mask, evt_valid_mask)
     # add them in place
     roles_arr = BK.get_value(role_predictions)  # [*, len-ef, len-evt]
     logprobs_arr = BK.get_value(role_logprobs)
     for bidx, one_roles_arr in enumerate(roles_arr):
         one_ef_items, one_evt_items = ef_items[bidx], evt_items[bidx]
         # =====
         # todo(note): delete original links!
         for z in one_ef_items:
             if z is not None:
                 z.links.clear()
         for z in one_evt_items:
             if z is not None:
                 z.links.clear()
         # =====
         one_logprobs = logprobs_arr[bidx]
         for ef_idx, one_ef in enumerate(one_ef_items):
             if one_ef is None:
                 continue
             for evt_idx, one_evt in enumerate(one_evt_items):
                 if one_evt is None:
                     continue
                 one_role_idx = int(one_roles_arr[ef_idx, evt_idx])
                 if one_role_idx > 0:  # link
                     this_hlidx = arg_linker.idx2hlidx(one_role_idx)
                     one_evt.add_arg(one_ef, role=str(this_hlidx), role_idx=this_hlidx,
                                     score=float(one_logprobs[ef_idx, evt_idx]))
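
For reference, the decode side reverses the fill pattern sketched under Example #1: keep only pairs whose predicted role index is positive and attach the log-probability of that cell as the link score. A hypothetical stand-alone sketch in plain NumPy (not zmsp's API):

    import numpy as np

    def decode_links(roles, logprobs):
        # roles, logprobs: [len_ef, len_evt] arrays from a pairwise role classifier
        links = []
        for ef_idx, evt_idx in zip(*np.nonzero(roles > 0)):
            links.append((int(ef_idx), int(evt_idx), int(roles[ef_idx, evt_idx]),
                          float(logprobs[ef_idx, evt_idx])))
        return links

    roles = np.array([[0, 3], [1, 0]])
    logprobs = np.array([[-2.3, -0.1], [-0.7, -1.9]])
    print(decode_links(roles, logprobs))  # [(0, 1, 3, -0.1), (1, 0, 1, -0.7)]
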
Example #3
File: head.py Project: ValentinaPy/zmsp
 def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
     if self.dmxnn:
         bsize, slen = BK.get_shape(input_mask)
         if sel_idxes is None:
             sel_idxes = BK.arange_idx(slen).unsqueeze(0)  # select all, [1, slen]
         ncand = BK.get_shape(sel_idxes, -1)
         # enc_expr aug with PE
         rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
         pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
         aug_enc_expr = BK.concat([
             pe_embeds.expand(bsize, -1, -1, -1),
             input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)
         ], -1)  # [*, ?, slen, D+Dpe]
         # [*, ?, slen, Denc]
         hidden_expr = self.e_encoder(
             aug_enc_expr.view(bsize * ncand, slen, -1),
             input_mask.unsqueeze(1).expand(-1, ncand, -1).contiguous().view(bsize * ncand, slen))
         hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
         # dynamic max-pooling (dist<0, dist=0, dist>0)
         NEG = Constants.REAL_PRAC_MIN
         mp_hiddens = []
         mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
         for mp_mask in mp_masks:
             float_mask = mp_mask.float() * input_mask.unsqueeze(-2)  # [*, ?, slen]
             valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [*, ?, 1]
             mask_neg_val = (1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
             # todo(+2): or do we simply multiply mask?
             mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
             mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
             mp_hiddens.append(self.special_drop(mp_hid))
             # mp_hiddens.append(mp_hid)
         final_hiddens = mp_hiddens
     else:
         hidden_expr = self.e_encoder(input_expr, input_mask)  # [*, slen, D']
         if sel_idxes is None:
             hidden_expr1 = hidden_expr
         else:
             hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes, -2)  # [*, ?, D']
         final_hiddens = [self.special_drop(hidden_expr1)]
     if self.lab_f_use_lexi:
         final_hiddens.append(BK.gather_first_dims(input_lexi, sel_idxes, -2))  # [*, ?, DLex]
     ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
     return ret_expr
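
The dmxnn branch relies on dynamic max-pooling over three relative-distance segments (left of the trigger, the trigger itself, right of it), implemented by pushing out-of-segment positions to a large negative value before taking the max. A plain PyTorch sketch of that trick, assuming the BK backend wraps torch tensors (shapes here are illustrative):

    import torch

    def segment_max_pool(hidden, seg_mask, NEG=-1e4):
        # hidden: [slen, D]; seg_mask: [slen], 1. inside the segment, 0. outside
        masked = hidden + (1. - seg_mask).unsqueeze(-1) * NEG  # out-of-segment -> very negative
        pooled = masked.max(dim=0).values                      # [D]
        # zero the result for empty segments, as the original valid_mask does
        return pooled * (seg_mask.sum() > 0).float()

    hidden = torch.randn(5, 8)
    rel_dist = torch.arange(5) - 2                 # trigger at position 2
    left = segment_max_pool(hidden, (rel_dist < 0).float())
    mid = segment_max_pool(hidden, (rel_dist == 0).float())
    right = segment_max_pool(hidden, (rel_dist > 0).float())
    print(left.shape, mid.shape, right.shape)      # three [8] vectors
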
Example #4
 def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
     mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     else:
         target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
         target_hids = self.hid_layer(target_reprs)
         target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
         pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
         target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
         target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones stay in range
         # get loss
         pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
         pred_loss_sum = (pred_losses * mask_valids).sum()
         pred_loss_count = mask_valids.sum()
         # argmax
         _, argmax_idxes = target_scores.max(-1)
         pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
         pred_corr_count = pred_corrs.sum()
         return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
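
The loss relies on BK.mask2idx to turn a 0/1 replacement mask into right-padded position indexes plus a validity mask, so only masked positions are scored. One possible plain-PyTorch implementation of that pattern (a sketch of the idea, not zmsp's actual BK.mask2idx):

    import torch

    def mask2idx(mask, padding_idx=0):
        # mask: [bsize, slen] of 0./1. -> (padded indexes, validity), both [bsize, max_count]
        counts = mask.sum(-1).long()
        max_count = int(counts.max().item())
        idxes = torch.full((mask.shape[0], max_count), padding_idx, dtype=torch.long)
        valids = torch.zeros(mask.shape[0], max_count)
        for b in range(mask.shape[0]):
            pos = mask[b].nonzero(as_tuple=False).squeeze(-1)
            idxes[b, :len(pos)] = pos
            valids[b, :len(pos)] = 1.
        return idxes, valids

    m = torch.tensor([[0., 1., 1., 0.], [1., 0., 0., 0.]])
    print(mask2idx(m))  # rows pick positions (1, 2) and (0,), padded to the same width
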
Example #5
 def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
     conf = self.conf
     free_mode = (force_widx is None)
     prev_state_h = prev_state[0]
     # =====
     # collect att scores
     key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
     query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
     orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
     orig_scores += (1.-input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
     # first maximum across the R dim (this step is hard max)
     maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
     if conf.zero_eos_score:
         # use a mask so the operation stays differentiable (gradients can flow back)
         tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
         tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
         maxr_scores *= tmp_mask
     # then select over the slen dim (this step is prob based)
     maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
     if free_mode:
         cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
         sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
     else:
         sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
         sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
     # then collect the info and perform labeling
     lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
     lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
     lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
     lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
     lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
     # final predicting labels
     # todo(+N): here we select only max at labeling part, only beam at previous one
     if free_mode:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
     else:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
     # no lab-logprob (*=0) for eos (sel_tok==0)
     sel_lab_logprobs *= (sel_tok_idxes>0).float()
     # compute next-state [*, ?, ~]
     # todo(note): here we flatten the first two dims
     tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
     tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
     tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
     tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                       for z in prev_state]  # [*, ?, ?, D]
     next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
     next_state = [z.view(tmp_rnn_dims) for z in next_state]
     return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
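
The end of _step flattens the [batch, beam] dims into one batch dimension before calling the RNN cell and restores them afterwards. A small PyTorch sketch of that reshape trick, with torch.nn.GRUCell standing in for self.rnn_unit and illustrative sizes:

    import torch

    bsize, beam, d_in, d_hid = 2, 3, 8, 16
    cell = torch.nn.GRUCell(d_in, d_hid)          # stand-in for self.rnn_unit

    x = torch.randn(bsize, beam, d_in)            # one input per beam candidate
    h = torch.randn(bsize, d_hid)                 # previous state, shared across the beam

    flat_x = x.view(bsize * beam, d_in)
    flat_h = h.unsqueeze(1).expand(-1, beam, -1).contiguous().view(bsize * beam, d_hid)
    flat_next = cell(flat_x, flat_h)              # a single cell call covers all beams
    next_h = flat_next.view(bsize, beam, d_hid)   # back to [bsize, beam, d_hid]
    print(next_h.shape)
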
Example #6
File: head.py Project: ValentinaPy/zmsp
 def _exclude_nil(self,
                  sel_idxes,
                  sel_valid_mask,
                  sel_lab_idxes,
                  sel_lab_embeds,
                  sel_logprobs=None,
                  sel_items_arr=None):
     # todo(note): assure that nil is 0
     sel_valid_mask = sel_valid_mask * (sel_lab_idxes != 0).float()  # not in-place
     # idx on idx
     s2_idxes, s2_valid_mask = BK.mask2idx(sel_valid_mask)
     sel_idxes = sel_idxes.gather(-1, s2_idxes)
     sel_valid_mask = s2_valid_mask
     sel_lab_idxes = sel_lab_idxes.gather(-1, s2_idxes)
     sel_lab_embeds = BK.gather_first_dims(sel_lab_embeds, s2_idxes, -2)
     sel_logprobs = None if sel_logprobs is None else sel_logprobs.gather(-1, s2_idxes)
     sel_items_arr = None if sel_items_arr is None \
         else sel_items_arr[np.arange(len(sel_items_arr))[:, np.newaxis], BK.get_value(s2_idxes)]
     return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs, sel_items_arr
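
_exclude_nil compacts the selections: entries whose label is NIL (index 0) are masked out, the surviving slots are converted back to indexes, and every parallel tensor is re-gathered with those indexes. A plain PyTorch sketch of that compaction (the s2_idxes/s2_valids values are hardcoded here to what a mask2idx call would return for this toy input):

    import torch

    lab_idxes = torch.tensor([[2, 0, 5], [0, 0, 1]])    # 0 = NIL
    valid_mask = torch.ones(2, 3)
    keep_mask = valid_mask * (lab_idxes != 0).float()    # [[1, 0, 1], [0, 0, 1]]

    # what a mask2idx(keep_mask) call would return for this toy input
    s2_idxes = torch.tensor([[0, 2], [2, 0]])
    s2_valids = torch.tensor([[1., 1.], [1., 0.]])

    compact_labs = lab_idxes.gather(-1, s2_idxes)        # [[2, 5], [1, 0]]
    print(compact_labs, s2_valids)
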
Example #7
 def forward_batch(self,
                   batched_ids: List,
                   batched_starts: List,
                   batched_typeids: List,
                   training: bool,
                   other_inputs: List[List] = None):
     conf = self.bconf
     tokenizer = self.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     MASK_IDX = tokenizer.mask_token_id
     CLS_IDX = tokenizer.cls_token_id
     SEP_IDX = tokenizer.sep_token_id
     if other_inputs is None:
         other_inputs = []
     # =====
     # batch: here add CLS and SEP
     bsize = len(batched_ids)
     max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
     input_shape = (bsize, max_len)
     # first collect on CPU
     input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
     input_ids_arr[:, 0] = CLS_IDX
     input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
     input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
     input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
     other_input_arrs = [
         np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs
     ]
     if conf.bert2_retinc_cls:  # act as the ROOT word
         input_is_start_arr[:, 0] = 1
     training_mask_rate = conf.bert2_training_mask_rate if training else 0.
     self_sample_stream = self.random_sample_stream
     for bidx in range(bsize):
         cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
         cur_end = len(cur_ids) + 2  # plus CLS and SEP
         if training_mask_rate > 0.:
             # input dropout
             input_ids_arr[bidx, 1:cur_end] = [
                 (MASK_IDX
                  if next(self_sample_stream) < training_mask_rate else z)
                 for z in cur_ids
             ] + [SEP_IDX]
         else:
             input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
         input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
         input_mask_arr[bidx, :cur_end] = 1.
         if batched_typeids is not None and batched_typeids[bidx] is not None:
             input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
         for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
             one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
     # arr to tensor
     input_ids_t = BK.input_idx(input_ids_arr)
     input_mask_t = BK.input_real(input_mask_arr)
     input_is_start_t = BK.input_idx(input_is_start_arr)
     input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
     other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
     # =====
     # forward (maybe need multiple times to fit maxlen constraint)
     MAX_LEN = 510  # save two for [CLS] and [SEP]
     BACK_LEN = 100  # for splitting cases, still remaining some of previous sub-tokens for context
     if max_len <= MAX_LEN:
         # directly once
         final_outputs = self.forward_features(input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
         start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
     else:
         all_outputs = []
         cur_sub_idx = 0
         slice_size = [bsize, 1]
         slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
         slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
         while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
             cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
             cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
             cur_input_ids_t = BK.concat([
                 slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end],
                 slice_sep
             ], 1)
             # here we simply reuse the original mask, taking one extra position on each side for the added CLS/SEP
             cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_input_typeid_t = None if input_typeid_t is None else input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_other_input_ts = [
                 z[:, cur_slice_start - 1:cur_slice_end + 1]
                 for z in other_input_ts
             ]
             cur_outputs = self.forward_features(cur_input_ids_t,
                                                 cur_input_mask_t,
                                                 cur_input_typeid_t,
                                                 cur_other_input_ts)
             # only include CLS in the first run, no SEP included
             if cur_sub_idx == 0:
                 # include CLS, exclude SEP
                 all_outputs.append(cur_outputs[:, :-1])
             else:
                 # include only new ones, discard BACK ones, exclude CLS, SEP
                 all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
                 zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                       f"for all-len={max_len}")
             cur_sub_idx = cur_slice_end
         final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
         start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
     start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
     return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
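
The first half of forward_batch assembles padded id and mask arrays on the CPU, prepending [CLS] and appending [SEP] to every sequence. A small NumPy sketch of that step with illustrative token ids (not tied to any particular tokenizer):

    import numpy as np

    PAD_IDX, CLS_IDX, SEP_IDX = 0, 101, 102              # illustrative ids
    batched_ids = [[7, 8, 9], [5]]

    bsize = len(batched_ids)
    max_len = max(len(z) for z in batched_ids) + 2        # room for [CLS] and [SEP]
    input_ids = np.full((bsize, max_len), PAD_IDX, dtype=np.int64)
    input_mask = np.zeros((bsize, max_len), dtype=np.float32)
    input_ids[:, 0] = CLS_IDX
    for bidx, ids in enumerate(batched_ids):
        cur_end = len(ids) + 2
        input_ids[bidx, 1:cur_end] = ids + [SEP_IDX]
        input_mask[bidx, :cur_end] = 1.
    print(input_ids)   # [[101 7 8 9 102], [101 5 102 0 0]]
    print(input_mask)
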
Example #8
 def loss(self,
          repr_ts,
          input_erase_mask_arr,
          orig_map: Dict,
          active_hid=True,
          **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # prepare idxes for the masked ones
     if self.add_root_token:  # offset for the special root added in embedder
         mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
         repr_mask_idxes = mask_idxes + 1
         mask_idxes.clamp_(min=0)
     else:
         mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
         repr_mask_idxes = mask_idxes
     # get the losses
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         return self._compile_component_loss("mlm", [])
     else:
         if not isinstance(repr_ts, (List, Tuple)):
             repr_ts = [repr_ts]
         target_word_scores, target_pos_scores = [], []
         target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
         for layer_idx in conf.loss_layers:
             # calculate scores
             target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
             if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                 target_hids = self.hid_layer(target_reprs)
             else:
                 target_hids = target_reprs
             if _tie_input_embeddings:
                 pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                 target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
             else:
                 target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
         # gather the losses
         all_losses = []
         for pred_name, target_scores, loss_lambda, range_min, range_max in \
                 zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                     [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
             if loss_lambda > 0.:
                 seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                 target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                 ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() * (target_idx_t <= range_max).float()
                 target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones stay in range
                 # calculate for each layer
                 all_layer_losses, all_layer_scores = [], []
                 for one_layer_idx, one_target_scores in enumerate(target_scores):
                     # get loss: [bsize, ?]
                     one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                     all_layer_losses.append(one_pred_losses)
                     # get scores
                     one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                     all_layer_scores.append(one_pred_scores)
                 # combine all layers
                 pred_losses = self.loss_comb_f(all_layer_losses)
                 pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                 pred_loss_count = ranged_mask_valids.sum()
                 # argmax
                 _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                 pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                 pred_corr_count = pred_corrs.sum()
                 # compile leaf loss
                 r_loss = LossHelper.compile_leaf_info(
                     pred_name,
                     pred_loss_sum,
                     pred_loss_count,
                     loss_lambda=loss_lambda,
                     corr=pred_corr_count)
                 all_losses.append(r_loss)
         return self._compile_component_loss("mlm", all_losses)
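
When tie_input_embeddings is on, the masked-LM head scores hidden states against a slice of the input embedding matrix instead of a separate output projection. A plain PyTorch sketch of that weight-tying branch, with illustrative names and sizes:

    import torch

    vocab_size, pred_size, dim = 1000, 500, 64
    embed = torch.nn.Embedding(vocab_size, dim)       # the input embedding table

    target_hids = torch.randn(2, 3, dim)              # [bsize, ?, dim]
    pred_W = embed.weight[:pred_size]                 # slice kept for prediction, [pred_size, dim]
    scores = torch.matmul(target_hids, pred_W.T)      # [bsize, ?, pred_size]
    print(scores.shape)                               # torch.Size([2, 3, 500])
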