def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
    conf = self.conf
    # scoring
    arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
    # loss
    bsize, max_len = BK.get_shape(mask_expr)
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in insts])
    # todo(note): here use the original idx of label, no shift!
    gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [bs, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [bs, Len]
    # collect the losses
    arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)  # [bs, 1]
    arange_m_expr = BK.arange_idx(max_len).unsqueeze(0)  # [1, Len]
    # logsoftmax and losses
    arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1), -1)  # [bs, m, h]
    lab_logsoftmaxs = BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
    arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr]  # [bs, Len]
    lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr, gold_labels_expr]  # [bs, Len]
    # head selection (no root)
    arc_loss_sum = (-arc_sel_ls * mask_expr)[:, 1:].sum()
    lab_loss_sum = (-lab_sel_ls * mask_expr)[:, 1:].sum()
    final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum
    final_loss_count = mask_expr[:, 1:].sum()
    return [[final_loss, final_loss_count]]
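# Illustrative sketch (not part of the original code): the arange-based advanced indexing used above to
# select the gold log-probs. BK is assumed to be a thin wrapper over torch, so this standalone demo uses
# torch directly; all names below are hypothetical.
def _demo_gold_logprob_gather():
    import torch
    bs, slen, nlab = 2, 5, 3
    lab_ls = torch.randn(bs, slen, slen, nlab).log_softmax(-1)  # [bs, m, h, Lab]
    gold_heads = torch.randint(0, slen, (bs, slen))             # [bs, m]
    gold_labels = torch.randint(0, nlab, (bs, slen))            # [bs, m]
    arange_bs = torch.arange(bs).unsqueeze(-1)                  # [bs, 1], broadcasts over m
    arange_m = torch.arange(slen).unsqueeze(0)                  # [1, m], broadcasts over bs
    sel = lab_ls[arange_bs, arange_m, gold_heads, gold_labels]  # [bs, m], gold label log-probs
    return sel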
def batch_inputs_h(self, insts: List[Sentence]):
    key, items_getter = self.extract_type, self.items_getter
    nil_idx = 0
    # get gold/input data and batch
    all_masks, all_idxes, all_items, all_valid = [], [], [], []
    all_idxes2, all_items2 = [], []  # secondary types
    for sent in insts:
        preps = sent.preps.get(key)  # not cached, rebuild them
        if preps is None:
            length = sent.length
            items = items_getter(sent)  # token-idx -> ...
            prep_masks, prep_idxes, prep_items = [0.] * length, [nil_idx] * length, [None] * length
            prep_idxes2, prep_items2 = [nil_idx] * length, [None] * length
            if items is None:
                # todo(note): there are samples that do not have entity annotations (KBP15)
                # final 0/1 indicates valid or not
                prep_valid = 0.
            else:
                prep_valid = 1.
                for one_item in items:
                    this_hwidx = one_item.mention.hard_span.head_wid
                    this_hlidx = one_item.type_idx
                    # todo(+N): ignore except the first two types (already ranked by type-freq)
                    if prep_idxes[this_hwidx] == 0:
                        prep_masks[this_hwidx] = 1.
                        prep_idxes[this_hwidx] = self.hlidx2idx(this_hlidx)  # change to int here!
                        prep_items[this_hwidx] = one_item
                    elif prep_idxes2[this_hwidx] == 0:
                        prep_idxes2[this_hwidx] = self.hlidx2idx(this_hlidx)  # change to int here!
                        prep_items2[this_hwidx] = one_item
            sent.preps[key] = (prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2)
        else:
            prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2 = preps
        # =====
        all_masks.append(prep_masks)
        all_idxes.append(prep_idxes)
        all_items.append(prep_items)
        all_valid.append(prep_valid)
        all_idxes2.append(prep_idxes2)
        all_items2.append(prep_items2)
    # pad and batch
    mention_masks = BK.input_real(self.padder_mask.pad(all_masks)[0])  # [*, slen]
    mention_idxes = BK.input_idx(self.padder_idxes.pad(all_idxes)[0])  # [*, slen]
    mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, slen]
    mention_valid = BK.input_real(all_valid)  # [*]
    mention_idxes2 = BK.input_idx(self.padder_idxes.pad(all_idxes2)[0])  # [*, slen]
    mention_items2_arr, _ = self.padder_items.pad(all_items2)  # [*, slen]
    return mention_masks, mention_idxes, mention_items_arr, mention_valid, mention_idxes2, mention_items2_arr
def pad_par(self, idxes: List, labels: List):
    par_idxes_t = BK.input_idx(idxes)
    labels_t = BK.input_idx(labels)
    # todo(note): specifically, <0 means non-exist
    # todo(note): an interesting bug: ">=" was once wrongly written as "<"; in that case, 0 acted as the
    #  parent of nodes that actually have no parent and are still to be attached, so patterns of
    #  "parent=0" could get overly positive scores
    # todo(note): ACTUALLY, mainly because of the difference between search and forward-backward!!
    par_mask_t = (par_idxes_t >= 0).float()
    par_idxes_t.clamp_(0)  # since -1 would be an illegal idx
    labels_t.clamp_(0)
    return par_idxes_t, labels_t, par_mask_t
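# Illustrative sketch (not part of the original code): why the mask is computed first and the clamp
# applied after. A parent index of -1 marks "no parent"; clamping it to 0 keeps the later gather legal,
# while the mask taken *before* clamping zeroes out whatever gets gathered at those positions.
# torch is assumed as the BK backend; all names below are hypothetical.
def _demo_mask_then_clamp():
    import torch
    par_idxes = torch.tensor([3, -1, 0, 2])              # -1 means non-exist
    par_mask = (par_idxes >= 0).float()                  # [1., 0., 1., 1.]
    safe_idxes = par_idxes.clamp(min=0)                  # [3, 0, 0, 2], all legal for indexing
    enc = torch.randn(5, 8)                              # [slen, D]
    par_repr = enc[safe_idxes] * par_mask.unsqueeze(-1)  # masked-out rows become zero vectors
    return par_repr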
def _collect_insts(self, ms_items: List, training):
    max_range = self.conf.max_range
    ret_efs, ret_sents, ret_bidxes, ret_head_idxes, ret_left_dists, ret_right_dists = [], [], [], [], [], []
    for batch_idx, one_item in enumerate(ms_items):
        one_sents = one_item.sents
        sid2sents = {s.sid: s for s in one_sents}  # not the sid in this list
        sid2offsets = {s.sid: v for s, v in zip(one_sents, one_item.offsets)}  # not the sid in this list
        # assert one_sents[0].sid == 0, "Currently only support fake doc!"
        one_center_idx = one_item.center_idx
        one_center_sent = one_sents[one_center_idx]
        # get target events
        one_center_evts = one_center_sent.events if training else one_center_sent.pred_events
        if one_center_evts is not None and len(one_center_evts) > 0:
            # todo(+N): is multi-event ok?
            # assert len(one_center_evts) == 1, "Currently only support one event at one sent!!"
            # get args
            for one_center_evt in one_center_evts:
                if one_center_evt.links is None:
                    continue
                for one_arg in one_center_evt.links:
                    one_ef = one_arg.ef
                    # only collect in-ranged ones
                    if one_ef.mention is not None and one_ef.mention.hard_span.sid in sid2sents:
                        hspan = one_ef.mention.hard_span
                        sid, head_wid, wid, wlen = hspan.sid, hspan.head_wid, hspan.wid, hspan.length
                        left_dist = head_wid - wid
                        right_dist = wid + wlen - 1 - head_wid
                        if training:
                            if left_dist >= max_range or right_dist >= max_range:
                                continue  # skip long spans in training
                        else:
                            # clear wid and wlen for testing
                            hspan.wid = hspan.head_wid
                            hspan.length = 1
                            left_dist = right_dist = 0
                        # add one
                        ret_sents.append(sid2sents[sid])
                        ret_efs.append(one_ef)  # todo(note): may repeat but does not matter
                        ret_bidxes.append(batch_idx)
                        ret_head_idxes.append(sid2offsets[sid] + head_wid - 1)  # minus ROOT offset
                        ret_left_dists.append(left_dist)
                        ret_right_dists.append(right_dist)
    return ret_efs, ret_sents, BK.input_idx(ret_bidxes), BK.input_idx(ret_head_idxes), \
        BK.input_idx(ret_left_dists), BK.input_idx(ret_right_dists)
def run_sents(self, all_sents: List, all_docs: List[DocInstance], training: bool, use_one_bucket=False):
    if use_one_bucket:
        # when we do not want to split, e.g., if we know the input lengths do not vary too much
        all_buckets = [all_sents]
    else:
        all_sents.sort(key=lambda x: x[0].length)
        all_buckets = self._bucket_sents_by_length(all_sents, self.bconf.enc_bucket_range)
    # doc hint
    use_doc_hint = self.use_doc_hint
    if use_doc_hint:
        dh_sent_repr = self.dh_node.run(all_docs)  # [NumDoc, MaxSent, D]
    else:
        dh_sent_repr = None
    # encoding for each of the buckets
    rets = []
    dh_add, dh_both, dh_cls = [self.dh_combine_method == z for z in ["add", "both", "cls"]]
    for one_bucket in all_buckets:
        one_sents = [z[0] for z in one_bucket]
        # [BS, Len, Di], [BS, Len]
        input_repr0, mask_arr0 = self._prepare_input(one_sents, training)
        if use_doc_hint:
            one_d_idxes = BK.input_idx([z[1] for z in one_bucket])
            one_s_idxes = BK.input_idx([z[2] for z in one_bucket])
            one_s_reprs = dh_sent_repr[one_d_idxes, one_s_idxes].unsqueeze(-2)  # [BS, 1, D]
            if dh_add:
                input_repr = input_repr0 + one_s_reprs  # [BS, slen, D]
                mask_arr = mask_arr0
            elif dh_both:
                input_repr = BK.concat([one_s_reprs, input_repr0, one_s_reprs], -2)  # [BS, 2+slen, D]
                mask_arr = np.pad(mask_arr0, ((0, 0), (1, 1)), 'constant', constant_values=1.)  # [BS, 2+slen]
            elif dh_cls:
                input_repr = BK.concat([one_s_reprs, input_repr0[:, 1:]], -2)  # [BS, slen, D]
                mask_arr = mask_arr0
            else:
                raise NotImplementedError()
        else:
            input_repr, mask_arr = input_repr0, mask_arr0
        # [BS, Len, De]
        enc_repr = self.enc(input_repr, mask_arr)
        # separate ones (possibly using detach to avoid gradients for some of them)
        enc_repr_ef = self.enc_ef(enc_repr.detach() if self.bconf.enc_ef_input_detach else enc_repr, mask_arr)
        enc_repr_evt = self.enc_evt(enc_repr.detach() if self.bconf.enc_evt_input_detach else enc_repr, mask_arr)
        if use_doc_hint and dh_both:
            one_ret = (one_sents, input_repr0, enc_repr_ef[:, 1:-1].contiguous(),
                       enc_repr_evt[:, 1:-1].contiguous(), mask_arr0)
        else:
            one_ret = (one_sents, input_repr0, enc_repr_ef, enc_repr_evt, mask_arr0)
        rets.append(one_ret)
    # todo(note): each returned tuple is (List[Sentence], Tensor, Tensor, Tensor, Tensor)
    return rets
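# Illustrative sketch (not part of the original code): one way to bucket length-sorted items so that
# lengths within a bucket differ by at most `bucket_range`, which is what _bucket_sents_by_length is
# assumed to do given its usage above. The function and its arguments here are hypothetical.
def _demo_bucket_by_length(sorted_items, length_f, bucket_range):
    buckets = []
    for item in sorted_items:
        # start a new bucket when the spread would exceed bucket_range
        if buckets and length_f(item) - length_f(buckets[-1][0]) <= bucket_range:
            buckets[-1].append(item)
        else:
            buckets.append([item])
    return buckets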
def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    # score
    scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
    # get gold
    gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.long)  # [bs, ?+rlen]
    for bidx, inst in enumerate(insts):
        cur_seq_idxes = getattr(inst, self.attr_name).idxes
        if self.add_root_token:
            gold_pidxes[bidx, 1:1 + len(cur_seq_idxes)] = cur_seq_idxes
        else:
            gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
    # get loss
    margin = self.margin.value
    gold_pidxes_t = BK.input_idx(gold_pidxes)
    gold_pidxes_t *= (gold_pidxes_t < self.pred_out_dim).long()  # 0 means invalid ones!!
    loss_mask_t = (gold_pidxes_t > 0).float() * mask_t  # [bs, ?+rlen]
    lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t, margin=margin)  # [bs, ?+rlen]
    # argmax
    _, argmax_idxes = scores_t.max(-1)
    pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
    # compile loss
    lab_loss = LossHelper.compile_leaf_info("slab", lab_losses_t.sum(), loss_mask_t.sum(), corr=pred_corrs.sum())
    return self._compile_component_loss(self.pname, [lab_loss])
def loss(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # use gold targets: only use positive samples!!
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
    realis_flist = [(-1 if (z is None or z.realis_idx is None) else z.realis_idx) for z in items_arr.flatten()]
    realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
    realis_mask = (realis_t >= 0).float()
    realis_t.clamp_(min=0)  # make sure all idxes are legal
    # -----
    # return 0 if there are no targets at all
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build losses
    loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens, realis_t, realis_mask, conf.lambda_realis)
    loss_item_type = self._get_one_loss(self.type_predictor, hiddens, labels_t, masks_t, conf.lambda_type)
    return [loss_item_realis, loss_item_type]
def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float],
          label_weight_list: List[float], bidxes_list: List[int]):
    # 1. collect (batched) features; todo(note): use prev state for scoring
    hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
    # 2. get new sreprs
    scorer = self.scorer
    s_enc = self.slayer
    bsize_range_t = BK.input_idx(bidxes_list)
    node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
    node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
    # label loss
    if self.system_labeled:
        node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
        _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
        label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
        label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
        final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
    else:
        label_scores = final_label_loss_sum = BK.zeros([])
    # arc loss
    node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
    _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
    arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
    final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
    # score reg
    return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
    mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
    if BK.get_shape(mask_idxes, -1) == 0:
        # no loss
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    else:
        target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
        target_hids = self.hid_layer(target_reprs)
        target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
        pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
        target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
        target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones in range
        # get loss
        pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
        pred_loss_sum = (pred_losses * mask_valids).sum()
        pred_loss_count = mask_valids.sum()
        # argmax
        _, argmax_idxes = target_scores.max(-1)
        pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
        pred_corr_count = pred_corrs.sum()
        return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
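# Illustrative sketch (not part of the original code): a minimal mask2idx in torch, assuming BK.mask2idx
# turns a 0/1 mask [bs, slen] into padded position lists plus a validity mask, as its usage above
# suggests. torch is assumed as the BK backend; names below are hypothetical.
def _demo_mask2idx(mask):
    import torch
    # mask: [bs, slen] of 0./1.
    counts = mask.sum(-1).long()  # [bs], number of selected positions per row
    max_cnt = int(counts.max().item()) if counts.numel() > 0 else 0
    idxes = torch.zeros(mask.size(0), max_cnt, dtype=torch.long)
    valids = torch.zeros(mask.size(0), max_cnt)
    for b in range(mask.size(0)):
        pos = mask[b].nonzero(as_tuple=True)[0]  # positions where mask==1
        idxes[b, :pos.numel()] = pos
        valids[b, :pos.numel()] = 1.
    return idxes, valids  # [bs, ?], [bs, ?]; padded slots stay at idx 0 with valid 0.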
def _fb_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
             evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin):
    # get the gold idxes
    arg_linker = self.arg_linker
    bsize, len_ef = ef_items.shape
    bsize2, len_evt = evt_items.shape
    assert bsize == bsize2
    gold_idxes = np.zeros([bsize, len_ef, len_evt], dtype=np.long)
    for one_gold_idxes, one_ef_items, one_evt_items in zip(gold_idxes, ef_items, evt_items):
        # todo(note): check each pair
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            role_map = {id(z.evt): z.role_idx for z in one_ef.links}  # todo(note): since we get the original linked ones
            for evt_idx, one_evt in enumerate(one_evt_items):
                pairwise_role_hlidx = role_map.get(id(one_evt))
                if pairwise_role_hlidx is not None:
                    pairwise_role_idx = arg_linker.hlidx2idx(pairwise_role_hlidx)
                    assert pairwise_role_idx > 0
                    one_gold_idxes[ef_idx, evt_idx] = pairwise_role_idx
    # get loss
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    if np.prod(gold_idxes.shape) == 0:
        # no instances!
        return [[BK.zeros([]), BK.zeros([])]]
    else:
        gold_idxes_t = BK.input_idx(gold_idxes)
        return arg_linker.loss(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                               ef_valid_mask, evt_valid_mask, gold_idxes_t, margin)
def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
    conf = self.conf
    # -----
    # special mode
    if conf.aug_word2 and conf.aug_word2_aug_encoder:
        _rop = RefreshOptions(training=False)  # special feature-mode!!
        self.embedder.refresh(_rop)
        self.encoder.refresh(_rop)
    # -----
    emb_t, mask_t = self.embedder(cur_input_map)
    rel_dist = cur_input_map.get("rel_dist", None)
    if rel_dist is not None:
        rel_dist = BK.input_idx(rel_dist)
    if conf.enc_choice == "vrec":
        enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
    elif conf.enc_choice == "original":
        # todo(note): change back to arr for backward compatibility
        assert rel_dist is None, "Original encoder does not support rel_dist"
        enc_t = self.encoder(emb_t, BK.get_value(mask_t))
        cache, enc_loss = None, None
    else:
        raise NotImplementedError()
    # another encoder based on attn
    final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
    if conf.aug_word2:
        emb2_t = self.aug_word2(insts)
        if conf.aug_word2_aug_encoder:
            # simply add them all together, detach orig-enc as features
            stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
            features = self.aug_mixturer(stack_hidden_t)
            aug_input = (emb2_t + conf.aug_detach_ratio * self.aug_detach_drop(features))
            final_enc_t, cache, enc_loss = self.aug_encoder(
                aug_input, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
        else:
            final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
    return emb_t, mask_t, final_enc_t, cache, enc_loss
def batch_inputs_g1(self, insts: List[Sentence]):
    train_reverse_evetns = self.conf.train_reverse_evetns  # todo(note): this option is from the derived class
    _tmp_f = (lambda x: list(reversed(x))) if train_reverse_evetns else (lambda x: x)
    key, items_getter = self.extract_type, self.items_getter
    # nil_idx = 0  # nil means eos
    # get gold/input data and batch
    all_widxes, all_lidxes, all_vmasks, all_items, all_valid = [], [], [], [], []
    for sent in insts:
        preps = sent.preps.get(key)  # not cached, rebuild them
        if preps is None:
            items = items_getter(sent)
            # todo(note): directly add, assume they are already sorted in a good way (widx+lidx); 0 (nil) as eos
            if items is None:
                prep_valid = 0.
                # prep_widxes, prep_lidxes, prep_vmasks, prep_items = [0], [0], [1.], [None]
                prep_widxes, prep_lidxes, prep_vmasks, prep_items = [], [], [], []
            else:
                prep_valid = 1.
                prep_widxes = _tmp_f([z.mention.hard_span.head_wid for z in items]) + [0]
                prep_lidxes = _tmp_f([self.hlidx2idx(z.type_idx) for z in items]) + [0]
                prep_vmasks = [1.] * (len(items) + 1)
                prep_items = _tmp_f(items.copy()) + [None]
            sent.preps[key] = (prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid)
        else:
            prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid = preps
        # =====
        all_widxes.append(prep_widxes)
        all_lidxes.append(prep_lidxes)
        all_vmasks.append(prep_vmasks)
        all_items.append(prep_items)
        all_valid.append(prep_valid)
    # pad and batch
    mention_widxes = BK.input_idx(self.padder_idxes.pad(all_widxes)[0])  # [*, ?]
    mention_lidxes = BK.input_idx(self.padder_idxes.pad(all_lidxes)[0])  # [*, ?]
    mention_vmasks = BK.input_real(self.padder_mask.pad(all_vmasks)[0])  # [*, ?]
    mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, ?]
    mention_valid = BK.input_real(all_valid)  # [*]
    return mention_widxes, mention_lidxes, mention_vmasks, mention_items_arr, mention_valid
def init_oracle_mask(inst: ParseInstance, prev_arc_mask, prev_label):
    gold_heads = inst.heads.vals[1:]
    gold_labels = inst.labels.idxes[1:]
    gold_idxes = [i + 1 for i in range(len(gold_heads))]
    prev_arc_mask[gold_idxes, gold_heads] = 1.
    prev_label[gold_idxes, gold_heads] = BK.input_idx(gold_labels, BK.CPU_DEVICE)
    return prev_arc_mask, prev_label
def __call__(self, char_input, add_root_token: bool):
    char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
    if add_root_token:
        slice_shape = BK.get_shape(char_input_t)
        slice_shape[-2] = 1
        char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
        char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
    else:
        char_input_t1 = char_input_t
    char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
    char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
    return self.dropout(char_cat_expr)  # todo(note): only final dropout
def __call__(self, input_v, add_root_token: bool):
    if isinstance(input_v, np.ndarray):
        # directly use this [batch_size, slen] as input
        posi_idxes = BK.input_idx(input_v)
        expr = self.node(posi_idxes)  # [batch_size, slen, D]
    else:
        # input is a shape as prepared by "PosiHelper"
        batch_size, max_len = input_v
        if add_root_token:
            max_len += 1
        posi_idxes = BK.arange_idx(max_len)  # [1?+slen], root=0 is added here
        expr = self.node(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
    return self.dropout(expr)
def pad_chs(self, idxes_list: List[List], labels_list: List[List]):
    start_posi = self.chs_start_posi
    if start_posi < 0:
        # truncate
        idxes_list = [x[start_posi:] for x in idxes_list]
    # overall valid mask: whether each instance has any valid children
    chs_valid = [(0. if len(z) == 0 else 1.) for z in idxes_list]
    # only proceed when every instance in the batch has valid children
    if all(x > 0 for x in chs_valid):
        padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(idxes_list)  # [*, max-ch], [*, max-ch]
        if self.use_label_feat:
            if start_posi < 0:
                # truncate
                labels_list = [x[start_posi:] for x in labels_list]
            padded_chs_labels, _ = self.ch_label_padder.pad(labels_list)  # [*, max-ch]
            chs_label_t = BK.input_idx(padded_chs_labels)
        else:
            chs_label_t = None
        chs_idxes_t, chs_mask_t, chs_valid_mask_t = \
            BK.input_idx(padded_chs_idxes), BK.input_real(padded_chs_mask), BK.input_real(chs_valid)
        return chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t
    else:
        return None, None, None, None
def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first maximize across the R dim (this step is hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use a mask so that the op stays differentiable
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]; todo(+3): using a soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # finally predict the labels
    # todo(+N): here we select only max at the labeling part, only beam at the previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]; todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # flattened to [*x?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    # iconf = self.conf.iconf
    with BK.no_grad_env():
        self.refresh_batch(False)
        # pruning and scores from g1
        valid_mask, go1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        mask_expr = BK.input_real(mask_arr)
        # decode
        final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
        ret_heads, ret_labels, _, _ = self.dl.decode(insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
        # collect the results together
        all_heads = Helper.join_list(ret_heads)
        if ret_labels is None:
            # todo(note): simply get labels from the go1-label classifier; must provide g1parser
            if go1_pack is None:
                _, go1_pack = self._get_g1_pack(insts, 1., 1.)
            _, go1_label_max_idxes = go1_pack[1].max(-1)  # [bs, slen, slen]
            pred_heads_arr, _ = self.predict_padder.pad(all_heads)  # [bs, slen]
            pred_heads_expr = BK.input_idx(pred_heads_arr)
            pred_labels_expr = BK.gather_one_lastdim(go1_label_max_idxes, pred_heads_expr).squeeze(-1)
            all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
        else:
            all_labels = np.concatenate(ret_labels, 0)
        # ===== assign; todo(warn): here, the labels are directly the original idxes, no need to change
        for one_idx, one_inst in enumerate(insts):
            cur_length = len(one_inst) + 1
            one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length])  # directly int-vals for heads
            one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab)
            # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
        # =====
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
    # -----
    info = {"sent": len(insts), "tok": sum(map(len, insts))}
    return info
def __call__(self, input, add_root_token: bool):
    voc = self.voc
    # todo(note): append a [cls/root] idx, currently using "bos"
    input_t = BK.input_idx(input)  # [*, slen]
    # rare unk in training
    if self.rop.training and self.use_rare_unk:
        rare_unk_rate = self.ec_conf.comp_rare_unk
        cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).detach().long()
        input_t = input_t * (1 - cur_unk_imask) + self.voc.unk * cur_unk_imask
    # root
    if add_root_token:
        input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1] + [1], voc.bos, dtype=input_t.dtype)  # [*, 1]
        input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
    else:
        input_t_p1 = input_t
    expr = self.E(input_t_p1)  # [*, 1?+slen, D]
    return self.dropout(expr)
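# Illustrative sketch (not part of the original code): the rare-word UNK replacement above in standalone
# torch form. rare_mask is assumed to be a per-vocab-id 0/1 indicator of rare words; with probability
# `rate`, a rare token id is swapped for unk_id during training. Names below are hypothetical.
def _demo_rare_unk(input_t, rare_mask, unk_id, rate):
    import torch
    # input_t: [*, slen] long; rare_mask: [vocab] of 0./1.
    hit = (rare_mask[input_t] * (torch.rand(input_t.shape) < rate).float()).long()  # [*, slen]
    return input_t * (1 - hit) + unk_id * hit  # replace hit positions with unk_id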
def calc_repr(s_enc: SL0Layer, features_group, enc_expr, bidxes_expr):
    cur_idxes, par_idxes, labels, chs_idxes, chs_labels = features_group
    # get padded idxes: [*] or [*, ?]
    cur_idxes_t = BK.input_idx(cur_idxes)
    par_idxes_t, label_t, par_mask_t = s_enc.pad_par(par_idxes, labels)
    chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t = s_enc.pad_chs(chs_idxes, chs_labels)
    # gather enc-expr: [*, D], [*, D], [*, max-chs, D]
    dim1_range_t = bidxes_expr
    dim2_range_t = dim1_range_t.unsqueeze(-1)
    cur_t = enc_expr[dim1_range_t, cur_idxes_t]
    par_t = enc_expr[dim1_range_t, par_idxes_t]
    chs_t = None if chs_idxes_t is None else enc_expr[dim2_range_t, chs_idxes_t]
    # update reprs: [*, D]
    new_srepr = s_enc.calculate_repr(cur_t, par_t, label_t, par_mask_t,
                                     chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t)
    return cur_idxes_t, new_srepr
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
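# Illustrative sketch (not part of the original code): the label-recovery step after MST decoding, i.e.
# BK.gather_one_lastdim expressed as plain torch.gather. For each modifier, take the argmax label of the
# edge to its decoded head. torch is assumed as the BK backend.
def _demo_recover_labels(labels_argmax, mst_heads):
    # labels_argmax: [bs, m, h] long, argmax label per (mod, head) pair
    # mst_heads: [bs, m] long, decoded head of each modifier
    return labels_argmax.gather(-1, mst_heads.unsqueeze(-1)).squeeze(-1)  # [bs, m]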
def _decode(self, insts: List[ParseInstance], full_score, mask_expr, misc_prefix):
    # decode
    mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
    mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
    mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(
        full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
    if self.conf.iconf.output_marginals:
        # todo(note): here, we care about marginals for arcs
        # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
        arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
        bsize, max_len = BK.get_shape(mask_expr)
        idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
        idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
        output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr, BK.input_idx(mst_heads_arr)]
        mst_marg_arr = BK.get_value(output_marg)
    else:
        mst_marg_arr = None
    # ===== assign; todo(warn): here, the labels are directly the original idxes, no need to change
    for one_idx, one_inst in enumerate(insts):
        cur_length = mst_lengths[one_idx]
        one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-vals for heads
        one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
        one_scores = mst_scores_arr[one_idx][:cur_length]
        one_inst.pred_par_scores.set_vals(one_scores)
        # extra output
        one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
        if mst_marg_arr is not None:
            one_inst.extra_pred_misc[misc_prefix + "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()
def arange_cache(self, bidxes):
    new_bsize = len(bidxes)
    # if the idxes are already fine, then no need to select
    if not Helper.check_is_range(bidxes, self.cur_bsize):
        # masks are on CPU to make assigning easier
        bidxes_ct = BK.input_idx(bidxes, BK.CPU_DEVICE)
        self.scoring_fixed_mask_ct = self.scoring_fixed_mask_ct.index_select(0, bidxes_ct)
        self.scoring_mask_ct = self.scoring_mask_ct.index_select(0, bidxes_ct)
        self.oracle_mask_ct = self.oracle_mask_ct.index_select(0, bidxes_ct)
        # other things are all on the target device (possibly GPU)
        bidxes_device = BK.to_device(bidxes_ct)
        self.enc_repr = self.enc_repr.index_select(0, bidxes_device)
        self.scoring_cache.arange_cache(bidxes_device)
        # oracles
        self.oracle_mask_t = self.oracle_mask_t.index_select(0, bidxes_device)
        self.oracle_label_t = self.oracle_label_t.index_select(0, bidxes_device)
    # update bsize
    self.update_bsize(new_bsize)
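# Illustrative sketch (not part of the original code): rearranging batched beam state with index_select,
# as done above. After pruning or reordering hypotheses, every cached tensor is re-indexed along dim 0
# with the surviving batch indices. torch is assumed as the BK backend; names below are hypothetical.
def _demo_rearrange_cache(cached_tensors, bidxes):
    import torch
    idx = torch.tensor(bidxes, dtype=torch.long)
    # each tensor in cached_tensors is [old_bs, ...]; the result is [new_bs, ...]
    return [t.index_select(0, idx) for t in cached_tensors]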
def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    if self.add_root_token:
        repr_t = repr_t[:, 1:]
        mask_t = mask_t[:, 1:]
    # score
    scores_t = self._score(repr_t)  # [bs, rlen, D]
    # get gold
    gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.long)  # [bs, rlen]
    for bidx, inst in enumerate(insts):
        cur_seq_idxes = getattr(inst, self.attr_name).idxes
        gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
    # get loss
    gold_pidxes_t = BK.input_idx(gold_pidxes)
    nll_loss_sum = self.neg_log_likelihood_loss(scores_t, mask_t.bool(), gold_pidxes_t)
    if not conf.div_by_tok:
        # otherwise div by sent
        nll_loss_sum *= (mask_t.sum() / len(insts))
    # compile loss
    crf_loss = LossHelper.compile_leaf_info("crf", nll_loss_sum.sum(), mask_t.sum())
    return self._compile_component_loss(self.pname, [crf_loss])
def refresh(self, rop=None):
    super().refresh(rop)
    # no need to fix0 for None since already done in the Embedding
    # refresh layered embeddings (in training, we should not be in no-grad mode)
    # todo(note): here, there can be dropouts
    layered_prei_arrs = self.hl_vocab.layered_prei
    layered_pool_links_padded_arrs = self.hl_vocab.layered_pool_links_padded
    layered_pool_links_mask_arrs = self.hl_vocab.layered_pool_links_mask
    layered_isnil = self.hl_vocab.layered_pool_isnil
    for i in range(self.max_layer):
        # [N, ?, D] -> [N, D] -> [D, N]
        self.layered_embeds_pred[i] = (
            BK.input_real(layered_pool_links_mask_arrs[i]).unsqueeze(-1)
            * self.pool_pred(layered_pool_links_padded_arrs[i])).sum(-2).transpose(0, 1).contiguous()
        # [N, ?, D] -> [N, D]
        self.layered_embeds_lookup[i] = (
            BK.input_real(layered_pool_links_mask_arrs[i]).unsqueeze(-1)
            * self.pool_lookup(layered_pool_links_padded_arrs[i])).sum(-2)
        # [?] of idxes/masks
        self.layered_prei[i] = BK.input_idx(layered_prei_arrs[i])
        self.layered_isnil[i] = BK.input_real(layered_isnil[i])  # is-nil mask
def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs):
    self.refresh_batch(training)
    margin = self.margin.value
    # gold heads and labels
    gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts])
    gold_labels_arr, _ = self.predict_padder.pad([self.real2pred_labels(z.labels.idxes) for z in annotated_insts])
    gold_heads_expr = BK.input_idx(gold_heads_arr)  # [BS, Len]
    gold_labels_expr = BK.input_idx(gold_labels_arr)  # [BS, Len]
    # ===== calculate
    scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training)
    full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr)
    #
    final_losses = None
    if self.norm_local or self.norm_single:
        select_label_score = self._score_label_selected(
            scoring_expr_pack, mask_expr, training, margin, gold_heads_expr, gold_labels_expr)
        # already added margin previously
        losses_heads = losses_labels = None
        if self.loss_prob:
            if self.norm_local:
                losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_hinge:
            if self.norm_local:
                losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr)
                losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr)
            elif self.norm_single:
                single_sample = self.conf.tconf.loss_single_sample
                losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample,
                                                   is_hinge=True, margin=margin)
                losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample,
                                                    is_hinge=True, margin=margin)
            # simply adding
            final_losses = losses_heads + losses_labels
        elif self.loss_mr:
            # special treatment!
            probs_heads = BK.softmax(full_arc_score, dim=-1)  # [bs, m, h]
            probs_labels = BK.softmax(select_label_score, dim=-1)  # [bs, m, h]
            # select
            probs_head_gold = BK.gather_one_lastdim(probs_heads, gold_heads_expr).squeeze(-1)  # [bs, m]
            probs_label_gold = BK.gather_one_lastdim(probs_labels, gold_labels_expr).squeeze(-1)  # [bs, m]
            # root and pad will be excluded later
            # Reward = \sum_i 1.*marginal(GEdge_i); while for global models, need gradients on marginal-functions
            # todo(warn): problematic since the steps will be quite small; not used!
            final_losses = (mask_expr - probs_head_gold * probs_label_gold)  # let loss >= 0
    elif self.norm_global:
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, training, margin,
                                                  gold_heads_expr, gold_labels_expr)
        # for this one, use the merged full score
        full_score = full_arc_score.unsqueeze(-1) + full_label_score  # [BS, m, h, L]
        # +1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray([len(z) + 1 for z in annotated_insts], dtype=np.int32)
        # do inference
        if self.loss_prob:
            marginals_expr = self._marginal(full_score, mask_expr, mst_lengths_arr)  # [BS, m, h, L]
            final_losses = self._losses_global_prob(full_score, gold_heads_expr, gold_labels_expr,
                                                    marginals_expr, mask_expr)
            if self.alg_proj:
                # todo(+N): deal with the search-error-like problem: discard unproj neg losses
                #  (score > weighted-avg); this might be too loose, although the unproj edges are few?
                gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs for z in annotated_insts])
                gold_unproj_expr = BK.input_real(gold_unproj_arr)  # [BS, Len]
                comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr)
                final_losses = BK.max_elem(final_losses, comparing_expr)
        elif self.loss_hinge:
            pred_heads_arr, pred_labels_arr, _ = self._decode(full_score, mask_expr, mst_lengths_arr)
            pred_heads_expr = BK.input_idx(pred_heads_arr)  # [BS, Len]
            pred_labels_expr = BK.input_idx(pred_labels_arr)  # [BS, Len]
            #
            final_losses = self._losses_global_hinge(full_score, gold_heads_expr, gold_labels_expr,
                                                     pred_heads_expr, pred_labels_expr, mask_expr)
        elif self.loss_mr:
            # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge
            #  or marginal-two-edges
            raise NotImplementedError("Not implemented for global-loss + mr.")
    elif self.norm_hlocal:
        # firstly, label losses are the same
        select_label_score = self._score_label_selected(scoring_expr_pack, mask_expr, training, margin,
                                                        gold_heads_expr, gold_labels_expr)
        losses_labels = BK.loss_nll(select_label_score, gold_labels_expr)
        # then specially for the arc loss
        children_masks_arr, _ = self.hlocal_padder.pad([z.get_children_mask_arr() for z in annotated_insts])
        children_masks_expr = BK.input_real(children_masks_arr)  # [bs, h, m]
        # [bs, h]
        # todo(warn): use prod rather than sum, but still only an approximation for the top-down factorization
        # losses_arc = -BK.log(BK.sum(BK.softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr,
        #                             dim=-1) + (1 - mask_expr))
        losses_arc = -BK.sum(BK.log_softmax(full_arc_score, -2).transpose(-1, -2) * children_masks_expr, dim=-1)
        # including the root-head is important
        losses_arc[:, 1] += losses_arc[:, 0]
        final_losses = losses_arc + losses_labels
    #
    # jpos loss? (the same mask as parsing)
    jpos_losses_expr = jpos_pack[1]
    if jpos_losses_expr is not None:
        final_losses += jpos_losses_expr
    # collect loss with mask, also excluding the first symbol of ROOT
    final_losses_masked = (final_losses * mask_expr)[:, 1:]
    final_loss_sum = BK.sum(final_losses_masked)
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_tok = sum(len(z) for z in annotated_insts)
    if self.conf.tconf.loss_div_tok:
        final_loss = final_loss_sum / num_valid_tok
    else:
        final_loss = final_loss_sum / num_sent
    #
    final_loss_sum_val = float(BK.get_value(final_loss_sum))
    info = {"sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val}
    if training:
        BK.backward(final_loss, loss_factor)
    return info
def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr, mb_valid_expr, mb_go1_pack,
            training: bool, margin: float):
    # =====
    use_sib, use_gp = self.use_sib, self.use_gp
    # =====
    mb_size = len(mb_insts)
    mat_shape = BK.get_shape(mb_valid_expr)
    max_slen = mat_shape[-1]
    # step 1: extract the candidate features
    batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = self.helper.get_cand_features(mb_valid_expr)
    # =====
    # step 2: high order scoring
    # step 2.1: basic scoring, [*], [*, Lab]
    arc_scores, lab_scores = self._get_basic_score(mb_enc_expr, batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes)
    cur_system_labeled = (lab_scores is not None)
    # step 2.2: margin
    # get gold labels, which can be useful for later calculating loss
    if training:
        gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = \
            [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_insts(mb_insts, use_sib, use_gp)]
        # add the margins to the scores: (m,h), (m,sib), (m,gp)
        cur_margin = margin / self.margin_div
        self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_lab_idxes,
                                  batch_idxes, m_idxes, h_idxes, arc_scores, lab_scores, cur_margin, cur_margin)
        if use_sib:
            self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_sib_idxes, gold_lab_idxes,
                                      batch_idxes, m_idxes, sib_idxes, arc_scores, lab_scores, cur_margin, cur_margin)
        if use_gp:
            self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_gp_idxes, gold_lab_idxes,
                                      batch_idxes, m_idxes, gp_idxes, arc_scores, lab_scores, cur_margin, cur_margin)
        # may be useful for later training
        gold_pack = (mb_size, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes)
    else:
        gold_pack = None
    # step 2.3: o1scores
    if mb_go1_pack is not None:
        go1_arc_scores, go1_lab_scores = mb_go1_pack
        # todo(note): go1_arc_scores is not added here, but used as the input to the dec-algo
        if cur_system_labeled:
            lab_scores += go1_lab_scores[batch_idxes, m_idxes, h_idxes]
    else:
        go1_arc_scores = None
    # step 2.4: max out labels; todo(+N): or using logsumexp here?
    if cur_system_labeled:
        max_lab_scores, max_lab_idxes = lab_scores.max(-1)
        final_scores = arc_scores + max_lab_scores  # [*], final input arc scores
    else:
        max_lab_idxes = None
        final_scores = arc_scores
    # =====
    # step 3: actual decode
    res_heads = []
    for sid, inst in enumerate(mb_insts):
        slen = len(inst) + 1  # plus one for the art-root
        arr_o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int())
        arr_o1_scores = BK.get_value(go1_arc_scores[sid, :slen, :slen].double()) \
            if (go1_arc_scores is not None) else None
        cur_bidx_mask = (batch_idxes == sid)
        input_pack = [m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores]
        one_heads = self.helper.decode_one(slen, self.projective, arr_o1_masks, arr_o1_scores,
                                           input_pack, cur_bidx_mask)
        res_heads.append(one_heads)
    # =====
    # step 4: get labels back and pred_pack
    pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, _ = \
        [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_preds(res_heads, None, use_sib, use_gp)]
    if cur_system_labeled:
        # obtain hit components
        pred_hit_mask = self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_h_idxes,
                                           batch_idxes, m_idxes, h_idxes)
        if use_sib:
            pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_sib_idxes,
                                                batch_idxes, m_idxes, sib_idxes)
        if use_gp:
            pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_gp_idxes,
                                                batch_idxes, m_idxes, gp_idxes)
        # get pred labels (there should be only one hit per mod!)
        pred_labels = BK.constants_idx([mb_size, max_slen], 0)
        pred_labels[batch_idxes[pred_hit_mask], m_idxes[pred_hit_mask]] = max_lab_idxes[pred_hit_mask]
        res_labels = BK.get_value(pred_labels)
        pred_lab_idxes = pred_labels[pred_b_idxes, pred_m_idxes]
    else:
        res_labels = None
        pred_lab_idxes = None
    pred_pack = (mb_size, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes)
    # return
    return res_heads, res_labels, gold_pack, pred_pack
def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1., rand_gen=None,
                assign_attns=False, **kwargs):
    # =====
    # import torch
    # torch.autograd.set_detect_anomaly(True)
    # with torch.autograd.detect_anomaly():
    # =====
    conf = self.conf
    self.refresh_batch(training)
    if len(insts) == 0:
        return {"fb": 0, "sent": 0, "tok": 0}
    # -----
    # copying instances for training: expand at dim0
    cur_copy = conf.train_inst_copy if training else 1
    copied_insts = insts * cur_copy
    all_losses = []
    # -----
    # original input
    input_map = self.inputter(copied_insts)
    # for the pretraining modules
    has_loss_mlm, has_loss_orp = (self.masklm.loss_lambda.value > 0.), (self.orderpr.loss_lambda.value > 0.)
    if (not has_loss_orp) and has_loss_mlm:
        # only for mlm
        masked_input_map, input_erase_mask_arr = self.masklm.mask_input(input_map, rand_gen=rand_gen)
        emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(masked_input_map, collect_loss=True)
        all_losses.append(enc_loss)
        # mlm loss; todo(note): currently only using one layer
        mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map)
        all_losses.append(mlm_loss)
        # assign values
        if assign_attns:
            # may repeat and only keep the last one, but it does not matter!
            self._assign_attns_item(copied_insts, "mask", input_erase_mask_arr=input_erase_mask_arr, cache=cache)
        # agreement loss
        if cur_copy > 1:
            all_losses.extend(self._get_agr_loss("agr_mlm", cache, copy_num=cur_copy))
    if has_loss_orp:
        disturbed_input_map = self.orderpr.disturb_input(input_map, rand_gen=rand_gen)
        if has_loss_mlm:
            # further mask some
            disturb_keep_arr = disturbed_input_map.get("disturb_keep", None)
            assert disturb_keep_arr is not None, "No keep region for mlm!"
            # todo(note): in this mode we assume add_root, so here exclude arti-root by [:,1:]
            masked_input_map, input_erase_mask_arr = \
                self.masklm.mask_input(input_map, rand_gen=rand_gen, extra_mask_arr=disturb_keep_arr[:, 1:])
            disturbed_input_map.update(masked_input_map)  # update
        emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(disturbed_input_map, collect_loss=True)
        all_losses.append(enc_loss)
        # orp loss
        if conf.orp_loss_special:
            orp_loss = self.orderpr.loss_special(enc_t, mask_t, disturbed_input_map.get("disturb_keep", None),
                                                 disturbed_input_map, self.masklm)
        else:
            orp_input_attn = self.prepr_f(cache, disturbed_input_map.get("rel_dist"))
            orp_loss = self.orderpr.loss(enc_t, orp_input_attn, mask_t, disturbed_input_map.get("disturb_keep", None))
        all_losses.append(orp_loss)
        # mlm loss
        if has_loss_mlm:
            mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map)
            all_losses.append(mlm_loss)
        # assign values
        if assign_attns:
            # may repeat and only keep the last one, but it does not matter!
            self._assign_attns_item(copied_insts, "dist", abs_posi_arr=disturbed_input_map.get("posi"), cache=cache)
        # agreement loss
        if cur_copy > 1:
            all_losses.extend(self._get_agr_loss("agr_orp", cache, copy_num=cur_copy))
    if self.plainlm.loss_lambda.value > 0.:
        if conf.enc_choice == "vrec":
            # special case for blm
            emb_t, mask_t = self.embedder(input_map)
            rel_dist = input_map.get("rel_dist", None)
            if rel_dist is not None:
                rel_dist = BK.input_idx(rel_dist)
            # two directions
            true_rel_dist = self._get_rel_dist(BK.get_shape(mask_t, -1))  # q-k: [len_q, len_k]
            enc_t1, cache1, enc_loss1 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist <= 0).float(),
                                                     rel_dist=rel_dist, collect_loss=True)
            enc_t2, cache2, enc_loss2 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist >= 0).float(),
                                                     rel_dist=rel_dist, collect_loss=True)
            assert not self.rpreper.active, "TODO: Not supported for this mode"
            all_losses.extend([enc_loss1, enc_loss2])
            # plm loss with explicit two inputs
            plm_loss = self.plainlm.loss([enc_t1, enc_t2], input_map)
            all_losses.append(plm_loss)
        else:
            # here use the original input
            emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True)
            all_losses.append(enc_loss)
            # plm loss
            plm_loss = self.plainlm.loss(enc_t, input_map)
            all_losses.append(plm_loss)
            # agreement loss
            assert self.lambda_agree.value == 0., "Not implemented for this mode"
    # =====
    # task loss
    dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda = \
        [0. if z is None else z.loss_lambda.value for z in [self.dpar, self.upos, self.ner]]
    if any(z > 0. for z in [dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda]):
        # here use the original input
        emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True, insts=insts)
        all_losses.append(enc_loss)
        # parsing loss
        if dpar_loss_lambda > 0.:
            dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
            dpar_loss = self.dpar.loss(copied_insts, enc_t, dpar_input_attn, mask_t)
            all_losses.append(dpar_loss)
        # pos loss
        if upos_loss_lambda > 0.:
            upos_loss = self.upos.loss(copied_insts, enc_t, mask_t)
            all_losses.append(upos_loss)
        # ner loss
        if ner_loss_lambda > 0.:
            ner_loss = self.ner.loss(copied_insts, enc_t, mask_t)
            all_losses.append(ner_loss)
    # -----
    info = self.collect_loss_and_backward(all_losses, training, loss_factor)
    info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
    return info
def _get_rel_dist_embed(self, rel_dist, use_abs: bool):
    if use_abs:
        rel_dist = BK.input_idx(rel_dist).abs()
    ret = self.rel_dist_embed(rel_dist)  # [bs, len, len, H]
    return ret
def loss(self, repr_t, orig_map: Dict, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # --
    # specify input
    add_root_token = self.add_root_token
    # get from inputs
    if isinstance(repr_t, (list, tuple)):
        l2r_repr_t, r2l_repr_t = repr_t
    elif self.split_input_blm:
        l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
    else:
        l2r_repr_t, r2l_repr_t = repr_t, None  # l2r and r2l
    word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
    slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
    if add_root_token:
        l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0s at the front, [bs, 2+rlen-1]
    else:
        l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at the front, [bs, 1+rlen-1]
    # gather the losses
    all_losses = []
    pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
    if _tie_input_embeddings:
        pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
    else:
        pred_W = None
    # get input embeddings for output
    for pred_name, hid_node, pred_node, input_t, trg_t in \
            zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
        if input_t is None:
            continue
        # hidden
        hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
        # pred: [bs, slen, Vsize]
        if _tie_input_embeddings:
            scores_t = BK.matmul(hid_t, pred_W.T)
        else:
            scores_t = pred_node(hid_t)
        # loss
        mask_t = ((trg_t >= pred_range_min) & (trg_t <= pred_range_max)).float()  # [bs, slen]
        trg_t.clamp_(max=pred_range_max)  # make it in range
        losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
        _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
        corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
        # compile leaf loss
        one_loss = LossHelper.compile_leaf_info(pred_name, losses_t.sum(), mask_t.sum(),
                                                loss_lambda=1., corr=corrs_t.sum())
        all_losses.append(one_loss)
    return self._compile_component_loss("plm", all_losses)
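# Illustrative sketch (not part of the original code): building the shifted l2r/r2l bidirectional-LM
# targets with zero padding, mirroring the concat logic above for the no-root case. torch is assumed as
# the BK backend; names below are hypothetical.
def _demo_blm_targets(word_t):
    import torch
    # word_t: [bs, rlen] long
    zero_col = torch.zeros(word_t.size(0), 1, dtype=torch.long)  # [bs, 1]
    l2r_trg = torch.cat([word_t[:, 1:], zero_col], -1)   # at position i, predict word i+1; [bs, rlen]
    r2l_trg = torch.cat([zero_col, word_t[:, :-1]], -1)  # at position i, predict word i-1; [bs, rlen]
    return l2r_trg, r2l_trg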