Example #1
 def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
     conf = self.conf
     # scoring
     arc_score, lab_score = self._score(enc_expr,
                                        mask_expr)  # [bs, m, h, *]
     # loss
     bsize, max_len = BK.get_shape(mask_expr)
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in insts])
     # todo(note): here use the original idx of label, no shift!
     gold_labels_arr, _ = self.predict_padder.pad(
         [z.labels.idxes for z in insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
     # collect the losses
     arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
     arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
     # logsoftmax and losses
     arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1),
                                      -1)  # [bs, m, h]
     lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
     arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                  gold_heads_expr]  # [bs, Len]
     lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr,
                                  gold_heads_expr,
                                  gold_labels_expr]  # [bs, Len]
     # head selection (no root)
     arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
     lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
     final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
     final_loss_count = mask_expr[:, 1:].sum()
     return [[final_loss, final_loss_count]]
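
A minimal PyTorch sketch of the broadcasted advanced-indexing pattern used above to pick out gold-head and gold-label log-probabilities, assuming BK.arange_idx / BK.log_softmax behave like their torch counterparts; shapes and names here are illustrative only:

    import torch

    bs, m, h, n_lab = 2, 5, 5, 3                            # batch, modifiers, heads, labels
    arc_ls = torch.randn(bs, m, h).log_softmax(-1)          # [bs, m, h]
    lab_ls = torch.randn(bs, m, h, n_lab).log_softmax(-1)   # [bs, m, h, Lab]
    gold_heads = torch.randint(0, h, (bs, m))               # [bs, m]
    gold_labels = torch.randint(0, n_lab, (bs, m))          # [bs, m]

    arange_bs = torch.arange(bs).unsqueeze(-1)              # [bs, 1]
    arange_m = torch.arange(m).unsqueeze(0)                 # [1, m]
    # broadcasted indexing picks, per (batch, modifier), the gold head's log-prob
    arc_sel = arc_ls[arange_bs, arange_m, gold_heads]                # [bs, m]
    lab_sel = lab_ls[arange_bs, arange_m, gold_heads, gold_labels]   # [bs, m]
    assert arc_sel.shape == (bs, m) and lab_sel.shape == (bs, m)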
Example #2
 def batch_inputs_h(self, insts: List[Sentence]):
     key, items_getter = self.extract_type, self.items_getter
     nil_idx = 0
     # get gold/input data and batch
     all_masks, all_idxes, all_items, all_valid = [], [], [], []
     all_idxes2, all_items2 = [], []  # secondary types
     for sent in insts:
         preps = sent.preps.get(key)
         # not cached, rebuild them
         if preps is None:
             length = sent.length
             items = items_getter(sent)
             # token-idx -> ...
             prep_masks = [0.] * length
             prep_idxes = [nil_idx] * length
             prep_items = [None] * length
             prep_idxes2, prep_items2 = [nil_idx] * length, [None] * length
             if items is None:
                 # todo(note): there are samples that do not have entity annotations (KBP15)
                 #  final 0/1 indicates valid or not
                 prep_valid = 0.
             else:
                 prep_valid = 1.
                 for one_item in items:
                     this_hwidx = one_item.mention.hard_span.head_wid
                     this_hlidx = one_item.type_idx
                     # todo(+N): ignore except the first two types (already ranked by type-freq)
                     if prep_idxes[this_hwidx] == 0:
                         prep_masks[this_hwidx] = 1.
                         prep_idxes[this_hwidx] = self.hlidx2idx(
                             this_hlidx)  # change to int here!
                         prep_items[this_hwidx] = one_item
                     elif prep_idxes2[this_hwidx] == 0:
                         prep_idxes2[this_hwidx] = self.hlidx2idx(
                             this_hlidx)  # change to int here!
                         prep_items2[this_hwidx] = one_item
             sent.preps[key] = (prep_masks, prep_idxes, prep_items,
                                prep_valid, prep_idxes2, prep_items2)
         else:
             prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2 = preps
         # =====
         all_masks.append(prep_masks)
         all_idxes.append(prep_idxes)
         all_items.append(prep_items)
         all_valid.append(prep_valid)
         all_idxes2.append(prep_idxes2)
         all_items2.append(prep_items2)
     # pad and batch
     mention_masks = BK.input_real(
         self.padder_mask.pad(all_masks)[0])  # [*, slen]
     mention_idxes = BK.input_idx(
         self.padder_idxes.pad(all_idxes)[0])  # [*, slen]
     mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, slen]
     mention_valid = BK.input_real(all_valid)  # [*]
     mention_idxes2 = BK.input_idx(
         self.padder_idxes.pad(all_idxes2)[0])  # [*, slen]
     mention_items2_arr, _ = self.padder_items.pad(all_items2)  # [*, slen]
     return mention_masks, mention_idxes, mention_items_arr, mention_valid, mention_idxes2, mention_items2_arr
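
The batching above depends on padder objects (self.padder_mask, self.padder_idxes, ...) whose API is not shown. As a rough illustration of what such a pad call has to produce, here is a hypothetical pad_lists helper in NumPy that pads ragged per-sentence lists into a rectangular array plus a 0/1 mask; the real padders may behave differently:

    import numpy as np

    def pad_lists(lists, pad_val=0, dtype=np.int64):
        """Pad ragged lists to [bs, max_len] and also return a 0/1 validity mask."""
        bs = len(lists)
        max_len = max((len(x) for x in lists), default=1)
        arr = np.full((bs, max_len), pad_val, dtype=dtype)
        mask = np.zeros((bs, max_len), dtype=np.float32)
        for i, x in enumerate(lists):
            arr[i, :len(x)] = x
            mask[i, :len(x)] = 1.
        return arr, mask

    idxes, mask = pad_lists([[3, 1], [2, 2, 2], []])
    assert idxes.shape == (3, 3) and mask.sum() == 5.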
Example #3
 def pad_par(self, idxes: List, labels: List):
     par_idxes_t = BK.input_idx(idxes)
     labels_t = BK.input_idx(labels)
     # todo(note): specifically, <0 means non-exist
     # todo(note): an interesting bug: ">=" was once wrongly written as "<"; in that case index 0 acts as the parent of nodes that actually have no parent and are still to be attached, so patterns of "parent=0" may receive overly positive scores
     # todo(note): ACTUALLY, mainly because of the difference in search and forward-backward!!
     par_mask_t = (par_idxes_t >= 0).float()
     par_idxes_t.clamp_(0)  # since -1 will be illegal idx
     labels_t.clamp_(0)
     return par_idxes_t, labels_t, par_mask_t
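
A small PyTorch illustration of the "mask then clamp" trick above (assuming BK.input_idx wraps plain tensors): -1 marks a missing parent, the float mask remembers which positions were real, and clamping makes every index legal for later gathers:

    import torch

    par_idxes_t = torch.tensor([[2, -1, 0], [-1, 3, 1]])
    par_mask_t = (par_idxes_t >= 0).float()   # 1. where a parent exists
    par_idxes_t = par_idxes_t.clamp(min=0)    # -1 -> 0, now a safe index
    # par_mask_t == [[1., 0., 1.], [0., 1., 1.]]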
Example #4
 def _collect_insts(self, ms_items: List, training):
     max_range = self.conf.max_range
     ret_efs, ret_sents, ret_bidxes, ret_head_idxes, ret_left_dists, ret_right_dists = [], [], [], [], [], []
     for batch_idx, one_item in enumerate(ms_items):
         one_sents = one_item.sents
         sid2sents = {s.sid: s for s in one_sents}  # not the sid in this list
         sid2offsets = {s.sid: v for s, v in zip(one_sents, one_item.offsets)}  # not the sid in this list
         # assert one_sents[0].sid == 0, "Currently only support fake doc!"
         one_center_idx = one_item.center_idx
         one_center_sent = one_sents[one_center_idx]
         # get target events
         one_center_evts = one_center_sent.events if training else one_center_sent.pred_events
         if one_center_evts is not None and len(one_center_evts) > 0:
             # todo(+N): is multi-event ok?
             # assert len(one_center_evts) == 1, "Currently only support one event at one sent!!"
             # get args
             for one_center_evt in one_center_evts:
                 if one_center_evt.links is None:
                     continue
                 for one_arg in one_center_evt.links:
                     one_ef = one_arg.ef
                     # only collect in-ranged ones
                     if one_ef.mention is not None and one_ef.mention.hard_span.sid in sid2sents:
                         hspan = one_ef.mention.hard_span
                         sid, head_wid, wid, wlen = hspan.sid, hspan.head_wid, hspan.wid, hspan.length
                         left_dist = head_wid - wid
                         right_dist = wid + wlen - 1 - head_wid
                         if training:
                             if left_dist >= max_range or right_dist >= max_range:
                                 continue  # skip long spans in training
                         else:
                             # clear wid and wlen for testing
                             hspan.wid = hspan.head_wid
                             hspan.length = 1
                             left_dist = right_dist = 0
                         # add one
                         ret_sents.append(sid2sents[sid])
                         ret_efs.append(one_ef)  # todo(note): may repeat but does not matter
                         ret_bidxes.append(batch_idx)
                         ret_head_idxes.append(sid2offsets[sid] + head_wid - 1)  # minus ROOT offset
                         ret_left_dists.append(left_dist)
                         ret_right_dists.append(right_dist)
     return ret_efs, ret_sents, BK.input_idx(ret_bidxes), BK.input_idx(ret_head_idxes), \
            BK.input_idx(ret_left_dists), BK.input_idx(ret_right_dists)
Example #5
 def run_sents(self, all_sents: List, all_docs: List[DocInstance], training: bool, use_one_bucket=False):
     if use_one_bucket:
         all_buckets = [all_sents]  # used when we do not want to split, e.g., when we know the input lengths do not vary too much
     else:
         all_sents.sort(key=lambda x: x[0].length)
         all_buckets = self._bucket_sents_by_length(all_sents, self.bconf.enc_bucket_range)
     # doc hint
     use_doc_hint = self.use_doc_hint
     if use_doc_hint:
         dh_sent_repr = self.dh_node.run(all_docs)  # [NumDoc, MaxSent, D]
     else:
         dh_sent_repr = None
     # encoding for each of the bucket
     rets = []
     dh_add, dh_both, dh_cls = [self.dh_combine_method==z for z in ["add", "both", "cls"]]
     for one_bucket in all_buckets:
         one_sents = [z[0] for z in one_bucket]
         # [BS, Len, Di], [BS, Len]
         input_repr0, mask_arr0 = self._prepare_input(one_sents, training)
         if use_doc_hint:
             one_d_idxes = BK.input_idx([z[1] for z in one_bucket])
             one_s_idxes = BK.input_idx([z[2] for z in one_bucket])
             one_s_reprs = dh_sent_repr[one_d_idxes, one_s_idxes].unsqueeze(-2)  # [BS, 1, D]
             if dh_add:
                 input_repr = input_repr0 + one_s_reprs  # [BS, slen, D]
                 mask_arr = mask_arr0
             elif dh_both:
                 input_repr = BK.concat([one_s_reprs, input_repr0, one_s_reprs], -2)  # [BS, 2+slen, D]
                 mask_arr = np.pad(mask_arr0, ((0,0),(1,1)), 'constant', constant_values=1.)  # [BS, 2+slen]
             elif dh_cls:
                 input_repr = BK.concat([one_s_reprs, input_repr0[:, 1:]], -2)  # [BS, slen, D]
                 mask_arr = mask_arr0
             else:
                 raise NotImplementedError()
         else:
             input_repr, mask_arr = input_repr0, mask_arr0
         # [BS, Len, De]
         enc_repr = self.enc(input_repr, mask_arr)
         # separate ones (possibly using detach to avoid gradients for some of them)
         enc_repr_ef = self.enc_ef(enc_repr.detach() if self.bconf.enc_ef_input_detach else enc_repr, mask_arr)
         enc_repr_evt = self.enc_evt(enc_repr.detach() if self.bconf.enc_evt_input_detach else enc_repr, mask_arr)
         if use_doc_hint and dh_both:
             one_ret = (one_sents, input_repr0, enc_repr_ef[:, 1:-1].contiguous(), enc_repr_evt[:, 1:-1].contiguous(), mask_arr0)
         else:
             one_ret = (one_sents, input_repr0, enc_repr_ef, enc_repr_evt, mask_arr0)
         rets.append(one_ret)
     # todo(note): each returned tuple is (List[Sentence], input_repr, enc_repr_ef, enc_repr_evt, mask_arr)
     return rets
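
The helper self._bucket_sents_by_length is not shown; below is a plausible sketch of length-bucketing over already-sorted items (the real helper and the exact meaning of enc_bucket_range may differ): an item starts a new bucket once its length exceeds the first length in the current bucket by more than bucket_range.

    def bucket_by_length(sorted_items, bucket_range, length_f=len):
        buckets, cur = [], []
        for it in sorted_items:
            if cur and length_f(it) - length_f(cur[0]) > bucket_range:
                buckets.append(cur)
                cur = []
            cur.append(it)
        if cur:
            buckets.append(cur)
        return buckets

    lens = [3, 4, 4, 9, 10, 30]
    assert bucket_by_length(lens, 2, length_f=lambda x: x) == [[3, 4, 4], [9, 10], [30]]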
Example #6
 def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
     conf = self.conf
     # score
     scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
     # get gold
     gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.int64)  # [bs, ?+rlen]
     for bidx, inst in enumerate(insts):
         cur_seq_idxes = getattr(inst, self.attr_name).idxes
         if self.add_root_token:
             gold_pidxes[bidx, 1:1 + len(cur_seq_idxes)] = cur_seq_idxes
         else:
             gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
     # get loss
     margin = self.margin.value
     gold_pidxes_t = BK.input_idx(gold_pidxes)
     gold_pidxes_t *= (gold_pidxes_t <
                       self.pred_out_dim).long()  # 0 means invalid ones!!
     loss_mask_t = (gold_pidxes_t > 0).float() * mask_t  # [bs, ?+rlen]
     lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t,
                                margin=margin)  # [bs, ?+rlen]
     # argmax
     _, argmax_idxes = scores_t.max(-1)
     pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
     # compile loss
     lab_loss = LossHelper.compile_leaf_info("slab",
                                             lab_losses_t.sum(),
                                             loss_mask_t.sum(),
                                             corr=pred_corrs.sum())
     return self._compile_component_loss(self.pname, [lab_loss])
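
BK.loss_nll(scores, gold, margin=m) is used above without its definition. One plausible reading, sketched here in torch as an assumption rather than the actual helper, is cost-augmented NLL: add the margin to every non-gold score and then take the usual per-position cross entropy.

    import torch
    import torch.nn.functional as F

    def loss_nll_with_margin(scores, gold, margin=0.0):
        # scores: [*, C], gold: [*]
        if margin > 0.:
            aug = torch.full_like(scores, margin)
            aug.scatter_(-1, gold.unsqueeze(-1), 0.)   # no margin on the gold class
            scores = scores + aug
        losses = F.cross_entropy(scores.view(-1, scores.shape[-1]),
                                 gold.view(-1), reduction='none')
        return losses.view(gold.shape)

    scores = torch.randn(2, 6, 10)
    gold = torch.randint(0, 10, (2, 6))
    assert loss_nll_with_margin(scores, gold, margin=0.5).shape == (2, 6)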
Example #7
 def loss(self, ms_items: List, bert_expr, basic_expr):
     conf = self.conf
     bsize = len(ms_items)
     # use gold targets: only use positive samples!!
     offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
         ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
     realis_flist = [(-1 if (z is None or z.realis_idx is None) else z.realis_idx)
                     for z in items_arr.flatten()]
     realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
     realis_mask = (realis_t >= 0).float()
     realis_t.clamp_(min=0)  # make sure all idxes are legal
     # -----
     # return 0 if all no targets
     if BK.get_shape(offsets_t, -1) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
     # -----
     arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
     sel_basic_t = None if basic_expr is None else basic_expr[
         arange_t, offsets_t]  # [bsize, ?, D']
     hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
     # build losses
     loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens,
                                           realis_t, realis_mask,
                                           conf.lambda_realis)
     loss_item_type = self._get_one_loss(self.type_predictor, hiddens,
                                         labels_t, masks_t,
                                         conf.lambda_type)
     return [loss_item_realis, loss_item_type]
Example #8
 def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float], label_weight_list: List[float],
           bidxes_list: List[int]):
     # 1. collect (batched) features; todo(note): use prev state for scoring
     hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
     # 2. get new sreprs
     scorer = self.scorer
     s_enc = self.slayer
     bsize_range_t = BK.input_idx(bidxes_list)
     node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
     node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
     # label loss
     if self.system_labeled:
         node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
         _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
         label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
         label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
         final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
     else:
         label_scores = final_label_loss_sum = BK.zeros([])
     # arc loss
     node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
     _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
     arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
     final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
     # score reg
     return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
Example #9
 def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
     mask_idxes, mask_valids = BK.mask2idx(
         BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     else:
         target_reprs = BK.gather_first_dims(repr_t, mask_idxes,
                                             1)  # [bsize, ?, *]
         target_hids = self.hid_layer(target_reprs)
         target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
         pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
         target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
         target_idx_t[mask_valids < 1.] = 0  # make sure invalid ones in range
         # get loss
         pred_losses = BK.loss_nll(target_scores,
                                   target_idx_t)  # [bsize, ?]
         pred_loss_sum = (pred_losses * mask_valids).sum()
         pred_loss_count = mask_valids.sum()
         # argmax
         _, argmax_idxes = target_scores.max(-1)
         pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
         pred_corr_count = pred_corrs.sum()
         return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
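
BK.mask2idx is another undefined helper; the sketch below shows one way such a function could turn a 0/1 replacement mask into dense padded positions plus a validity mask, which also matches the width-0 early exit above. This is an assumption about the helper, written in plain torch:

    import torch

    def mask2idx(mask):                        # mask: [bs, slen] of 0./1.
        counts = mask.sum(-1).long()           # [bs]
        max_count = int(counts.max().item())
        idxes = torch.zeros(mask.shape[0], max_count, dtype=torch.long)
        valids = torch.zeros(mask.shape[0], max_count)
        for b in range(mask.shape[0]):
            pos = mask[b].nonzero(as_tuple=True)[0]
            idxes[b, :len(pos)] = pos
            valids[b, :len(pos)] = 1.
        return idxes, valids

    m = torch.tensor([[0., 1., 1., 0.], [1., 0., 0., 0.]])
    idx, val = mask2idx(m)
    # idx == [[1, 2], [0, 0]], val == [[1., 1.], [1., 0.]]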
Example #10
 def _fb_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
              evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin):
     # get the gold idxes
     arg_linker = self.arg_linker
     bsize, len_ef = ef_items.shape
     bsize2, len_evt = evt_items.shape
     assert bsize == bsize2
     gold_idxes = np.zeros([bsize, len_ef, len_evt], dtype=np.int64)
     for one_gold_idxes, one_ef_items, one_evt_items in zip(gold_idxes, ef_items, evt_items):
         # todo(note): check each pair
         for ef_idx, one_ef in enumerate(one_ef_items):
             if one_ef is None:
                 continue
             role_map = {id(z.evt): z.role_idx for z in one_ef.links}  # todo(note): since we get the original linked ones
             for evt_idx, one_evt in enumerate(one_evt_items):
                 pairwise_role_hlidx = role_map.get(id(one_evt))
                 if pairwise_role_hlidx is not None:
                     pairwise_role_idx = arg_linker.hlidx2idx(pairwise_role_hlidx)
                     assert pairwise_role_idx > 0
                     one_gold_idxes[ef_idx, evt_idx] = pairwise_role_idx
     # get loss
     repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
     repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
     if np.prod(gold_idxes.shape) == 0:
         # no instances!
         return [[BK.zeros([]), BK.zeros([])]]
     else:
         gold_idxes_t = BK.input_idx(gold_idxes)
         return arg_linker.loss(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes, ef_valid_mask, evt_valid_mask,
                                gold_idxes_t, margin)
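
BK.gather_first_dims(expr, idxes, -2) above appears to select per-batch sequence positions. In plain torch that effect can be obtained with a batch arange and advanced indexing; the snippet below only illustrates the pattern and is not the helper's real implementation:

    import torch

    enc = torch.randn(2, 7, 16)                    # [bs, slen, D]
    widxes = torch.tensor([[0, 3, 5], [2, 2, 6]])  # [bs, k] word positions per batch item
    bsel = torch.arange(2).unsqueeze(-1)           # [bs, 1]
    sel = enc[bsel, widxes]                        # [bs, k, D]
    assert sel.shape == (2, 3, 16)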
Example #11
 def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
     conf = self.conf
     # -----
     # special mode
     if conf.aug_word2 and conf.aug_word2_aug_encoder:
         _rop = RefreshOptions(training=False)  # special feature-mode!!
         self.embedder.refresh(_rop)
         self.encoder.refresh(_rop)
     # -----
     emb_t, mask_t = self.embedder(cur_input_map)
     rel_dist = cur_input_map.get("rel_dist", None)
     if rel_dist is not None:
         rel_dist = BK.input_idx(rel_dist)
     if conf.enc_choice == "vrec":
         enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
     elif conf.enc_choice == "original":  # todo(note): change back to arr for back compatibility
         assert rel_dist is None, "Original encoder does not support rel_dist"
         enc_t = self.encoder(emb_t, BK.get_value(mask_t))
         cache, enc_loss = None, None
     else:
         raise NotImplementedError()
     # another encoder based on attn
     final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
     if conf.aug_word2:
         emb2_t = self.aug_word2(insts)
         if conf.aug_word2_aug_encoder:
             # simply add them all together, detach orig-enc as features
             stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
             features = self.aug_mixturer(stack_hidden_t)
             aug_input = (emb2_t + conf.aug_detach_ratio*self.aug_detach_drop(features))
             final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t,
                                                             rel_dist=rel_dist, collect_loss=collect_loss)
         else:
             final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
     return emb_t, mask_t, final_enc_t, cache, enc_loss
Example #12
 def batch_inputs_g1(self, insts: List[Sentence]):
     train_reverse_evetns = self.conf.train_reverse_evetns  # todo(note): this option is from derived class
     _tmp_f = (lambda x: list(reversed(x))) if train_reverse_evetns else (lambda x: x)
     key, items_getter = self.extract_type, self.items_getter
     # nil_idx = 0  # nil means eos
     # get gold/input data and batch
     all_widxes, all_lidxes, all_vmasks, all_items, all_valid = [], [], [], [], []
     for sent in insts:
         preps = sent.preps.get(key)
         # not cached, rebuild them
         if preps is None:
             items = items_getter(sent)
             # todo(note): directly add, assume they are already sorted in a good way (widx+lidx); 0(nil) as eos
             if items is None:
                 prep_valid = 0.
                 # prep_widxes, prep_lidxes, prep_vmasks, prep_items = [0], [0], [1.], [None]
                 prep_widxes, prep_lidxes, prep_vmasks, prep_items = [], [], [], []
             else:
                 prep_valid = 1.
                 prep_widxes = _tmp_f(
                     [z.mention.hard_span.head_wid for z in items]) + [0]
                 prep_lidxes = _tmp_f(
                     [self.hlidx2idx(z.type_idx) for z in items]) + [0]
                 prep_vmasks = [1.] * (len(items) + 1)
                 prep_items = _tmp_f(items.copy()) + [None]
             sent.preps[key] = (prep_widxes, prep_lidxes, prep_vmasks,
                                prep_items, prep_valid)
         else:
             prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid = preps
         # =====
         all_widxes.append(prep_widxes)
         all_lidxes.append(prep_lidxes)
         all_vmasks.append(prep_vmasks)
         all_items.append(prep_items)
         all_valid.append(prep_valid)
     # pad and batch
     mention_widxes = BK.input_idx(
         self.padder_idxes.pad(all_widxes)[0])  # [*, ?]
     mention_lidxes = BK.input_idx(
         self.padder_idxes.pad(all_lidxes)[0])  # [*, ?]
     mention_vmasks = BK.input_real(
         self.padder_mask.pad(all_vmasks)[0])  # [*, ?]
     mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, ?]
     mention_valid = BK.input_real(all_valid)  # [*]
     return mention_widxes, mention_lidxes, mention_vmasks, mention_items_arr, mention_valid
Example #13
 def init_oracle_mask(inst: ParseInstance, prev_arc_mask, prev_label):
     gold_heads = inst.heads.vals[1:]
     gold_labels = inst.labels.idxes[1:]
     gold_idxes = [i + 1 for i in range(len(gold_heads))]
     prev_arc_mask[gold_idxes, gold_heads] = 1.
     prev_label[gold_idxes,
                gold_heads] = BK.input_idx(gold_labels, BK.CPU_DEVICE)
     return prev_arc_mask, prev_label
Example #14
 def __call__(self, char_input, add_root_token: bool):
     char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
     if add_root_token:
         slice_shape = BK.get_shape(char_input_t)
         slice_shape[-2] = 1
         char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
         char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
     else:
         char_input_t1 = char_input_t
     char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
     char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
     return self.dropout(char_cat_expr)  # todo(note): only final dropout
Example #15
 def __call__(self, input_v, add_root_token: bool):
     if isinstance(input_v, np.ndarray):
         # direct use this [batch_size, slen] as input
         posi_idxes = BK.input_idx(input_v)
         expr = self.node(posi_idxes)  # [batch_size, slen, D]
     else:
         # input is a shape as prepared by "PosiHelper"
         batch_size, max_len = input_v
         if add_root_token:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [1?+slen] add root=0 here
         expr = self.node(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     return self.dropout(expr)
Example #16
 def pad_chs(self, idxes_list: List[List], labels_list: List[List]):
     start_posi = self.chs_start_posi
     if start_posi < 0:  # truncate
         idxes_list = [x[start_posi:] for x in idxes_list]
     # overall valid mask
     chs_valid = [(0. if len(z) == 0 else 1.) for z in idxes_list]
     # if any valid children in the batch
     if all(x > 0 for x in chs_valid):
         padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(
             idxes_list)  # [*, max-ch], [*, max-ch]
         if self.use_label_feat:
             if start_posi < 0:  # truncate
                 labels_list = [x[start_posi:] for x in labels_list]
             padded_chs_labels, _ = self.ch_label_padder.pad(
                 labels_list)  # [*, max-ch]
             chs_label_t = BK.input_idx(padded_chs_labels)
         else:
             chs_label_t = None
         chs_idxes_t, chs_mask_t, chs_valid_mask_t = \
             BK.input_idx(padded_chs_idxes), BK.input_real(padded_chs_mask), BK.input_real(chs_valid)
         return chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t
     else:
         return None, None, None, None
Example #17
 def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
     conf = self.conf
     free_mode = (force_widx is None)
     prev_state_h = prev_state[0]
     # =====
     # collect att scores
     key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
     query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
     orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
     orig_scores += (1.-input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
     # first maximum across the R dim (this step is hard max)
     maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
     if conf.zero_eos_score:
         # use mask to make it able to be backward
         tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
         tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
         maxr_scores *= tmp_mask
     # then select over the slen dim (this step is prob based)
     maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
     if free_mode:
         cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
         sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
     else:
         sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
         sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
     # then collect the info and perform labeling
     lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
     lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
     lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]  # todo(+3): using soft version?
     lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
     lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
     # final predicting labels
     # todo(+N): here we select only max at labeling part, only beam at previous one
     if free_mode:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
     else:
         sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
     # no lab-logprob (*=0) for eos (sel_tok==0)
     sel_lab_logprobs *= (sel_tok_idxes>0).float()
     # compute next-state [*, ?, ~]
     # todo(note): here we flatten the first two dims
     tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
     tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
     tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
     tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                       for z in prev_state]  # [*, ?, ?, D]
     next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
     next_state = [z.view(tmp_rnn_dims) for z in next_state]
     return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
Example #18
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     # iconf = self.conf.iconf
     with BK.no_grad_env():
         self.refresh_batch(False)
         # pruning and scores from g1
         valid_mask, go1_pack = self._get_g1_pack(
             insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         # decode
         final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
         ret_heads, ret_labels, _, _ = self.dl.decode(
             insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
         # collect the results together
         all_heads = Helper.join_list(ret_heads)
         if ret_labels is None:
             # todo(note): simply get labels from the go1-label classifier; must provide g1parser
             if go1_pack is None:
                 _, go1_pack = self._get_g1_pack(insts, 1., 1.)
             _, go1_label_max_idxes = go1_pack[1].max(
                 -1)  # [bs, slen, slen]
             pred_heads_arr, _ = self.predict_padder.pad(
                 all_heads)  # [bs, slen]
             pred_heads_expr = BK.input_idx(pred_heads_arr)
             pred_labels_expr = BK.gather_one_lastdim(
                 go1_label_max_idxes, pred_heads_expr).squeeze(-1)
             all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
         else:
             all_labels = np.concatenate(ret_labels, 0)
         # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
         for one_idx, one_inst in enumerate(insts):
             cur_length = len(one_inst) + 1
             one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab)
             # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
         # =====
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Example #19
 def __call__(self, input, add_root_token: bool):
     voc = self.voc
     # todo(note): prepend a [cls/root] idx, currently use "bos"
     input_t = BK.input_idx(input)  # [*, slen]
     # rare unk in training
     if self.rop.training and self.use_rare_unk:
         rare_unk_rate = self.ec_conf.comp_rare_unk
         cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t))<rare_unk_rate)).detach().long()
         input_t = input_t * (1-cur_unk_imask) + self.voc.unk * cur_unk_imask
     # root
     if add_root_token:
         input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1]+[1], voc.bos, dtype=input_t.dtype)  # [*, 1]
         input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
     else:
         input_t_p1 = input_t
     expr = self.E(input_t_p1)  # [*, 1?+slen, D]
     return self.dropout(expr)
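
The root-prepending branch above builds a constant bos column and concatenates it in front of every sequence. A torch equivalent, assuming BK.constants / BK.concat map to torch.full / torch.cat and with an illustrative bos id:

    import torch

    input_t = torch.tensor([[4, 7, 9], [5, 6, 0]])   # [bs, slen]
    bos_id = 1                                       # stand-in for voc.bos
    root_col = torch.full(list(input_t.shape[:-1]) + [1], bos_id, dtype=input_t.dtype)
    input_with_root = torch.cat([root_col, input_t], dim=-1)   # [bs, 1+slen]
    assert input_with_root.shape == (2, 4)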
Example #20
 def calc_repr(s_enc: SL0Layer, features_group, enc_expr, bidxes_expr):
     cur_idxes, par_idxes, labels, chs_idxes, chs_labels = features_group
     # get padded idxes: [*] or [*, ?]
     cur_idxes_t = BK.input_idx(cur_idxes)
     par_idxes_t, label_t, par_mask_t = s_enc.pad_par(par_idxes, labels)
     chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t = s_enc.pad_chs(
         chs_idxes, chs_labels)
     # gather enc-expr: [*, D], [*, D], [*, max-chs, D]
     dim1_range_t = bidxes_expr
     dim2_range_t = dim1_range_t.unsqueeze(-1)
     cur_t = enc_expr[dim1_range_t, cur_idxes_t]
     par_t = enc_expr[dim1_range_t, par_idxes_t]
     chs_t = None if chs_idxes_t is None else enc_expr[dim2_range_t,
                                                       chs_idxes_t]
     # update reprs: [*, D]
     new_srepr = s_enc.calculate_repr(cur_t, par_t, label_t, par_mask_t,
                                      chs_t, chs_label_t, chs_mask_t,
                                      chs_valid_mask_t)
     return cur_idxes_t, new_srepr
Example #21
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr,
                                                 lengths_arr,
                                                 labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax,
                                                mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(
                mst_scores_arr)
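
BK.gather_one_lastdim, used here to read off the argmax label of each decoded head, presumably wraps torch.gather over the last dimension with an unsqueezed index; a minimal sketch of that pattern:

    import torch

    labels_argmax = torch.randint(0, 10, (2, 5, 5))   # [bs, m, h] best label per (m, h) arc
    mst_heads = torch.randint(0, 5, (2, 5))           # [bs, m] decoded head per modifier
    picked = labels_argmax.gather(-1, mst_heads.unsqueeze(-1)).squeeze(-1)   # [bs, m]
    assert picked.shape == (2, 5)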
Example #22
 def _decode(self, insts: List[ParseInstance], full_score, mask_expr,
             misc_prefix):
     # decode
     mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
     mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
     mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(
         full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
     if self.conf.iconf.output_marginals:
         # todo(note): here, we care about marginals for arc
         # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
         arc_marginals = nmarginal_unproj(full_score,
                                          mask_expr,
                                          None,
                                          labeled=True).sum(-1)
         bsize, max_len = BK.get_shape(mask_expr)
         idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
         idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
         output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr,
                                     BK.input_idx(mst_heads_arr)]
         mst_marg_arr = BK.get_value(output_marg)
     else:
         mst_marg_arr = None
     # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
     for one_idx, one_inst in enumerate(insts):
         cur_length = mst_lengths[one_idx]
         one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
         one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
         one_scores = mst_scores_arr[one_idx][:cur_length]
         one_inst.pred_par_scores.set_vals(one_scores)
         # extra output
         one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
         if mst_marg_arr is not None:
             one_inst.extra_pred_misc[misc_prefix + "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()
Example #23
 def arange_cache(self, bidxes):
     new_bsize = len(bidxes)
     # if the idxes are already fine, then no need to select
     if not Helper.check_is_range(bidxes, self.cur_bsize):
         # mask is on CPU to make assigning easier
         bidxes_ct = BK.input_idx(bidxes, BK.CPU_DEVICE)
         self.scoring_fixed_mask_ct = self.scoring_fixed_mask_ct.index_select(
             0, bidxes_ct)
         self.scoring_mask_ct = self.scoring_mask_ct.index_select(
             0, bidxes_ct)
         self.oracle_mask_ct = self.oracle_mask_ct.index_select(
             0, bidxes_ct)
         # other things are all on target-device (possibly GPU)
         bidxes_device = BK.to_device(bidxes_ct)
         self.enc_repr = self.enc_repr.index_select(0, bidxes_device)
         self.scoring_cache.arange_cache(bidxes_device)
         # oracles
         self.oracle_mask_t = self.oracle_mask_t.index_select(
             0, bidxes_device)
         self.oracle_label_t = self.oracle_label_t.index_select(
             0, bidxes_device)
         # update bsize
         self.update_bsize(new_bsize)
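
The cache re-arrangement above boils down to index_select along the batch dimension for every cached tensor, using the surviving (possibly repeated) batch indices; a small stand-alone illustration with plain torch tensors:

    import torch

    enc_repr = torch.randn(4, 7, 16)              # [bsize, slen, D]
    bidxes = torch.tensor([2, 0, 2])              # surviving rows, repeats allowed
    enc_repr = enc_repr.index_select(0, bidxes)   # [new_bsize, slen, D]
    assert enc_repr.shape == (3, 7, 16)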
Example #24
 def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
     conf = self.conf
     if self.add_root_token:
         repr_t = repr_t[:, 1:]
         mask_t = mask_t[:, 1:]
     # score
     scores_t = self._score(repr_t)  # [bs, rlen, D]
     # get gold
     gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.int64)  # [bs, rlen]
     for bidx, inst in enumerate(insts):
         cur_seq_idxes = getattr(inst, self.attr_name).idxes
         gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
     # get loss
     gold_pidxes_t = BK.input_idx(gold_pidxes)
     nll_loss_sum = self.neg_log_likelihood_loss(scores_t, mask_t.bool(),
                                                 gold_pidxes_t)
     if not conf.div_by_tok:  # otherwise div by sent
         nll_loss_sum *= (mask_t.sum() / len(insts))
     # compile loss
     crf_loss = LossHelper.compile_leaf_info("crf", nll_loss_sum.sum(),
                                             mask_t.sum())
     return self._compile_component_loss(self.pname, [crf_loss])
Example #25
 def refresh(self, rop=None):
     super().refresh(rop)
     # no need to fix0 for None since already done in the Embedding
     # refresh layered embeddings (in training, we should not be in no-grad mode)
     # todo(note): here, there can be dropouts
     layered_prei_arrs = self.hl_vocab.layered_prei
     layered_pool_links_padded_arrs = self.hl_vocab.layered_pool_links_padded
     layered_pool_links_mask_arrs = self.hl_vocab.layered_pool_links_mask
     layered_isnil = self.hl_vocab.layered_pool_isnil
     for i in range(self.max_layer):
         # [N, ?, D] -> [N, D] -> [D, N]
         self.layered_embeds_pred[i] = (
             BK.input_real(layered_pool_links_mask_arrs[i]).unsqueeze(-1)
             * self.pool_pred(layered_pool_links_padded_arrs[i])
         ).sum(-2).transpose(0, 1).contiguous()
         # [N, ?, D] -> [N, D]
         self.layered_embeds_lookup[i] = (
             BK.input_real(layered_pool_links_mask_arrs[i]).unsqueeze(-1)
             * self.pool_lookup(layered_pool_links_padded_arrs[i])
         ).sum(-2)
         # [?] of idxes/masks
         self.layered_prei[i] = BK.input_idx(layered_prei_arrs[i])
         self.layered_isnil[i] = BK.input_real(layered_isnil[i])  # is nil mask
Example #26
 def fb_on_batch(self,
                 annotated_insts,
                 training=True,
                 loss_factor=1,
                 **kwargs):
     self.refresh_batch(training)
     margin = self.margin.value
     # gold heads and labels
     gold_heads_arr, _ = self.predict_padder.pad(
         [z.heads.vals for z in annotated_insts])
     gold_labels_arr, _ = self.predict_padder.pad(
         [self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
     gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
     gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
     # ===== calculate
     scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
         annotated_insts, training)
     full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                           training, margin,
                                           gold_heads_expr)
     #
     final_losses = None
     if self.norm_local or self.norm_single:
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         # already added margin previously
         losses_heads = losses_labels = None
         if self.loss_prob:
             if self.norm_local:
                 losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                 losses_labels = BK.loss_nll(select_label_score,
                                             gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=False)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=False)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_hinge:
             if self.norm_local:
                 losses_heads = BK.loss_hinge(full_arc_score,
                                              gold_heads_expr)
                 losses_labels = BK.loss_hinge(select_label_score,
                                               gold_labels_expr)
             elif self.norm_single:
                 single_sample = self.conf.tconf.loss_single_sample
                 losses_heads = self._losses_single(full_arc_score,
                                                    gold_heads_expr,
                                                    single_sample,
                                                    is_hinge=True,
                                                    margin=margin)
                 losses_labels = self._losses_single(select_label_score,
                                                     gold_labels_expr,
                                                     single_sample,
                                                     is_hinge=True,
                                                     margin=margin)
             # simply adding
             final_losses = losses_heads + losses_labels
         elif self.loss_mr:
             # special treatment!
             probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
             probs_labels = BK.softmax(select_label_score,
                                       dim=-1)  # [bs, m, h]
             # select
             probs_head_gold = BK.gather_one_lastdim(
                 probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
             probs_label_gold = BK.gather_one_lastdim(
                 probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
             # root and pad will be excluded later
             # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need to gradient on marginal-functions
             # todo(warn): have problem since steps will be quite small, not used!
             final_losses = (mask_expr - probs_head_gold * probs_label_gold
                             )  # let loss>=0
     elif self.norm_global:
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, training,
                                                   margin, gold_heads_expr,
                                                   gold_labels_expr)
         # for this one, use the merged full score
         full_score = full_arc_score.unsqueeze(
             -1) + full_label_score  # [BS, m, h, L]
         # +=1 to include ROOT for mst decoding
         mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts],
                                      dtype=np.int32)
         # do inference
         if self.loss_prob:
             marginals_expr = self._marginal(
                 full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
             final_losses = self._losses_global_prob(
                 full_score, gold_heads_expr, gold_labels_expr,
                 marginals_expr, mask_expr)
             if self.alg_proj:
                 # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg),
                 #  but this might be too loose, although the unproj edges are few?
                 gold_unproj_arr, _ = self.predict_padder.pad(
                     [z.unprojs for z in annotated_insts])
                 gold_unproj_expr = BK.input_real(
                     gold_unproj_arr)  # [BS, Len]
                 comparing_expr = Constants.REAL_PRAC_MIN * (
                     1. - gold_unproj_expr)
                 final_losses = BK.max_elem(final_losses, comparing_expr)
         elif self.loss_hinge:
             pred_heads_arr, pred_labels_arr, _ = self._decode(
                 full_score, mask_expr, mst_lengths_arr)
             pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
             pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
             #
             final_losses = self._losses_global_hinge(
                 full_score, gold_heads_expr, gold_labels_expr,
                 pred_heads_expr, pred_labels_expr, mask_expr)
         elif self.loss_mr:
             # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges
             raise NotImplementedError(
                 "Not implemented for global-loss + mr.")
     elif self.norm_hlocal:
         # firstly label losses are the same
         select_label_score = self._score_label_selected(
             scoring_expr_pack, mask_expr, training, margin,
             gold_heads_expr, gold_labels_expr)
         losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
         # then specially for arc loss
         children_masks_arr, _ = self.hlocal_padder.pad(
             [z.get_children_mask_arr() for z in annotated_insts])
         children_masks_expr = BK.input_real(
             children_masks_arr)  # [bs, h, m]
         # [bs, h]
         # todo(warn): use prod rather than sum, but still only an approximation for the top-down
         # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1) + (1-mask_expr))
         losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(
             -1, -2) * children_masks_expr,
                              dim=-1)
         # including the root-head is important
         losses_arc[:, 1] += losses_arc[:, 0]
         final_losses = losses_arc + losses_labels
     #
     # jpos loss? (the same mask as parsing)
     jpos_losses_expr = jpos_pack[1]
     if jpos_losses_expr is not None:
         final_losses += jpos_losses_expr
     # collect loss with mask, also excluding the first symbol of ROOT
     final_losses_masked = (final_losses * mask_expr)[:, 1:]
     final_loss_sum = BK.sum(final_losses_masked)
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_tok = sum(len(z) for z in annotated_insts)
     if self.conf.tconf.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     #
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     if training:
         BK.backward(final_loss, loss_factor)
     return info
Example #27
 def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr,
             mb_valid_expr, mb_go1_pack, training: bool, margin: float):
     # =====
     use_sib, use_gp = self.use_sib, self.use_gp
     # =====
     mb_size = len(mb_insts)
     mat_shape = BK.get_shape(mb_valid_expr)
     max_slen = mat_shape[-1]
     # step 1: extract the candidate features
     batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = self.helper.get_cand_features(
         mb_valid_expr)
     # =====
     # step 2: high order scoring
     # step 2.1: basic scoring, [*], [*, Lab]
     arc_scores, lab_scores = self._get_basic_score(mb_enc_expr,
                                                    batch_idxes, m_idxes,
                                                    h_idxes, sib_idxes,
                                                    gp_idxes)
     cur_system_labeled = (lab_scores is not None)
     # step 2.2: margin
     # get gold labels, which can be useful for later calculating loss
     if training:
         gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = \
             [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_insts(mb_insts, use_sib, use_gp)]
         # add the margins to the scores: (m,h), (m,sib), (m,gp)
         cur_margin = margin / self.margin_div
         self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes,
                                   gold_h_idxes, gold_lab_idxes,
                                   batch_idxes, m_idxes, h_idxes,
                                   arc_scores, lab_scores, cur_margin,
                                   cur_margin)
         if use_sib:
             self._add_margin_inplaced(mat_shape, gold_b_idxes,
                                       gold_m_idxes, gold_sib_idxes,
                                       gold_lab_idxes, batch_idxes, m_idxes,
                                       sib_idxes, arc_scores, lab_scores,
                                       cur_margin, cur_margin)
         if use_gp:
             self._add_margin_inplaced(mat_shape, gold_b_idxes,
                                       gold_m_idxes, gold_gp_idxes,
                                       gold_lab_idxes, batch_idxes, m_idxes,
                                       gp_idxes, arc_scores, lab_scores,
                                       cur_margin, cur_margin)
         # may be useful for later training
         gold_pack = (mb_size, gold_b_idxes, gold_m_idxes, gold_h_idxes,
                      gold_sib_idxes, gold_gp_idxes, gold_lab_idxes)
     else:
         gold_pack = None
     # step 2.3: o1scores
     if mb_go1_pack is not None:
         go1_arc_scores, go1_lab_scores = mb_go1_pack
         # todo(note): go1_arc_scores is not added here, but as the input to the dec-algo
         if cur_system_labeled:
             lab_scores += go1_lab_scores[batch_idxes, m_idxes, h_idxes]
     else:
         go1_arc_scores = None
     # step 2.4: max out labels; todo(+N): or using logsumexp here?
     if cur_system_labeled:
         max_lab_scores, max_lab_idxes = lab_scores.max(-1)
         final_scores = arc_scores + max_lab_scores  # [*], final input arc scores
     else:
         max_lab_idxes = None
         final_scores = arc_scores
     # =====
     # step 3: actual decode
     res_heads = []
     for sid, inst in enumerate(mb_insts):
         slen = len(inst) + 1  # plus one for the art-root
         arr_o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int())
         arr_o1_scores = None if go1_arc_scores is None else \
             BK.get_value(go1_arc_scores[sid, :slen, :slen].double())
         cur_bidx_mask = (batch_idxes == sid)
         input_pack = [m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores]
         one_heads = self.helper.decode_one(slen, self.projective,
                                            arr_o1_masks, arr_o1_scores,
                                            input_pack, cur_bidx_mask)
         res_heads.append(one_heads)
     # =====
     # step 4: get labels back and pred_pack
     pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, _ = \
         [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_preds(res_heads, None, use_sib, use_gp)]
     if cur_system_labeled:
         # obtain hit components
         pred_hit_mask = self._get_hit_mask(mat_shape, pred_b_idxes,
                                            pred_m_idxes, pred_h_idxes,
                                            batch_idxes, m_idxes, h_idxes)
         if use_sib:
             pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes,
                                                 pred_m_idxes,
                                                 pred_sib_idxes,
                                                 batch_idxes, m_idxes,
                                                 sib_idxes)
         if use_gp:
             pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes,
                                                 pred_m_idxes,
                                                 pred_gp_idxes, batch_idxes,
                                                 m_idxes, gp_idxes)
         # get pred labels (there should be only one hit per mod!)
         pred_labels = BK.constants_idx([mb_size, max_slen], 0)
         pred_labels[batch_idxes[pred_hit_mask],
                     m_idxes[pred_hit_mask]] = max_lab_idxes[pred_hit_mask]
         res_labels = BK.get_value(pred_labels)
         pred_lab_idxes = pred_labels[pred_b_idxes, pred_m_idxes]
     else:
         res_labels = None
         pred_lab_idxes = None
     pred_pack = (mb_size, pred_b_idxes, pred_m_idxes, pred_h_idxes,
                  pred_sib_idxes, pred_gp_idxes, pred_lab_idxes)
     # return
     return res_heads, res_labels, gold_pack, pred_pack
Example #28
 def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1.,
                 rand_gen=None, assign_attns=False, **kwargs):
     # =====
     # import torch
     # torch.autograd.set_detect_anomaly(True)
     # with torch.autograd.detect_anomaly():
     # =====
     conf = self.conf
     self.refresh_batch(training)
     if len(insts) == 0:
         return {"fb": 0, "sent": 0, "tok": 0}
     # -----
     # copying instances for training: expand at dim0
     cur_copy = conf.train_inst_copy if training else 1
     copied_insts = insts * cur_copy
     all_losses = []
     # -----
     # original input
     input_map = self.inputter(copied_insts)
     # for the pretraining modules
     has_loss_mlm, has_loss_orp = (self.masklm.loss_lambda.value > 0.), (self.orderpr.loss_lambda.value > 0.)
     if (not has_loss_orp) and has_loss_mlm:  # only for mlm
         masked_input_map, input_erase_mask_arr = self.masklm.mask_input(input_map, rand_gen=rand_gen)
         emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(masked_input_map, collect_loss=True)
         all_losses.append(enc_loss)
         # mlm loss; todo(note): currently only using one layer
         mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map)
         all_losses.append(mlm_loss)
         # assign values
         if assign_attns:  # may repeat and only keep that last one, but does not matter!
             self._assign_attns_item(copied_insts, "mask", input_erase_mask_arr=input_erase_mask_arr, cache=cache)
         # agreement loss
         if cur_copy > 1:
             all_losses.extend(self._get_agr_loss("agr_mlm", cache, copy_num=cur_copy))
     if has_loss_orp:
         disturbed_input_map = self.orderpr.disturb_input(input_map, rand_gen=rand_gen)
         if has_loss_mlm:  # further mask some
             disturb_keep_arr = disturbed_input_map.get("disturb_keep", None)
             assert disturb_keep_arr is not None, "No keep region for mlm!"
             # todo(note): in this mode we assume add_root, so here exclude arti-root by [:,1:]
             masked_input_map, input_erase_mask_arr = \
                 self.masklm.mask_input(input_map, rand_gen=rand_gen, extra_mask_arr=disturb_keep_arr[:,1:])
             disturbed_input_map.update(masked_input_map)  # update
         emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(disturbed_input_map, collect_loss=True)
         all_losses.append(enc_loss)
         # orp loss
         if conf.orp_loss_special:
             orp_loss = self.orderpr.loss_special(enc_t, mask_t, disturbed_input_map.get("disturb_keep", None),
                                                  disturbed_input_map, self.masklm)
         else:
             orp_input_attn = self.prepr_f(cache, disturbed_input_map.get("rel_dist"))
             orp_loss = self.orderpr.loss(enc_t, orp_input_attn, mask_t, disturbed_input_map.get("disturb_keep", None))
         all_losses.append(orp_loss)
         # mlm loss
         if has_loss_mlm:
             mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map)
             all_losses.append(mlm_loss)
         # assign values
         if assign_attns:  # may repeat and only keep that last one, but does not matter!
             self._assign_attns_item(copied_insts, "dist", abs_posi_arr=disturbed_input_map.get("posi"), cache=cache)
         # agreement loss
         if cur_copy > 1:
             all_losses.extend(self._get_agr_loss("agr_orp", cache, copy_num=cur_copy))
     if self.plainlm.loss_lambda.value > 0.:
         if conf.enc_choice == "vrec":  # special case for blm
             emb_t, mask_t = self.embedder(input_map)
             rel_dist = input_map.get("rel_dist", None)
             if rel_dist is not None:
                 rel_dist = BK.input_idx(rel_dist)
             # two directions
             true_rel_dist = self._get_rel_dist(BK.get_shape(mask_t, -1))  # q-k: [len_q, len_k]
             enc_t1, cache1, enc_loss1 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist<=0).float(),
                                                      rel_dist=rel_dist, collect_loss=True)
             enc_t2, cache2, enc_loss2 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist>=0).float(),
                                                      rel_dist=rel_dist, collect_loss=True)
             assert not self.rpreper.active, "TODO: Not supported for this mode"
             all_losses.extend([enc_loss1, enc_loss2])
             # plm loss with explict two inputs
             plm_loss = self.plainlm.loss([enc_t1, enc_t2], input_map)
             all_losses.append(plm_loss)
         else:
             # here use original input
             emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True)
             all_losses.append(enc_loss)
             # plm loss
             plm_loss = self.plainlm.loss(enc_t, input_map)
             all_losses.append(plm_loss)
         # agreement loss
         assert self.lambda_agree.value==0., "Not implemented for this mode"
     # =====
     # task loss
     dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda = \
         [0. if z is None else z.loss_lambda.value for z in [self.dpar, self.upos, self.ner]]
     if any(z>0. for z in [dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda]):
         # here use original input
         emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True, insts=insts)
         all_losses.append(enc_loss)
         # parsing loss
         if dpar_loss_lambda > 0.:
             dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
             dpar_loss = self.dpar.loss(copied_insts, enc_t, dpar_input_attn, mask_t)
             all_losses.append(dpar_loss)
         # pos loss
         if upos_loss_lambda > 0.:
             upos_loss = self.upos.loss(copied_insts, enc_t, mask_t)
             all_losses.append(upos_loss)
         # ner loss
         if ner_loss_lambda > 0.:
             ner_loss = self.ner.loss(copied_insts, enc_t, mask_t)
             all_losses.append(ner_loss)
     # -----
     info = self.collect_loss_and_backward(all_losses, training, loss_factor)
     info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
     return info
Example #29
 def _get_rel_dist_embed(self, rel_dist, use_abs: bool):
     if use_abs:
         rel_dist = BK.input_idx(rel_dist).abs()
     ret = self.rel_dist_embed(rel_dist)  # [bs, len, len, H]
     return ret
Example #30
 def loss(self, repr_t, orig_map: Dict, **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # --
     # specify input
     add_root_token = self.add_root_token
     # get from inputs
     if isinstance(repr_t, (list, tuple)):
         l2r_repr_t, r2l_repr_t = repr_t
     elif self.split_input_blm:
         l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
     else:
         l2r_repr_t, r2l_repr_t = repr_t, None
     # l2r and r2l
     word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
     slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
     if add_root_token:
         l2r_trg_t = BK.concat([word_t, slice_zero_t],
                               -1)  # pad one extra 0, [bs, rlen+1]
         r2l_trg_t = BK.concat(
             [slice_zero_t, slice_zero_t, word_t[:, :-1]],
             -1)  # pad two extra 0 at front, [bs, 2+rlen-1]
     else:
         l2r_trg_t = BK.concat(
             [word_t[:, 1:], slice_zero_t], -1
         )  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
         r2l_trg_t = BK.concat(
             [slice_zero_t, word_t[:, :-1]],
             -1)  # pad one extra 0 at front, [bs, 1+rlen-1]
     # gather the losses
     all_losses = []
     pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
     if _tie_input_embeddings:
         pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
     else:
         pred_W = None
     # get input embeddings for output
     for pred_name, hid_node, pred_node, input_t, trg_t in \
                 zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                     [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
         if input_t is None:
             continue
         # hidden
         hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
         # pred: [bs, slen, Vsize]
         if _tie_input_embeddings:
             scores_t = BK.matmul(hid_t, pred_W.T)
         else:
             scores_t = pred_node(hid_t)
         # loss
         mask_t = ((trg_t >= pred_range_min) &
                   (trg_t <= pred_range_max)).float()  # [bs, slen]
         trg_t.clamp_(max=pred_range_max)  # make it in range
         losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
         _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
         corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
         # compile leaf loss
         one_loss = LossHelper.compile_leaf_info(pred_name,
                                                 losses_t.sum(),
                                                 mask_t.sum(),
                                                 loss_lambda=1.,
                                                 corr=corrs_t.sum())
         all_losses.append(one_loss)
     return self._compile_component_loss("plm", all_losses)