Code Example #1
 def fb_on_batch(self,
                 annotated_insts: List[ParseInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     # encode
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
         annotated_insts, training)
     mask_expr = BK.input_real(mask_arr)
     # the parsing loss
     arc_score = self.scorer_helper.score_arc(enc_repr)
     lab_score = self.scorer_helper.score_label(enc_repr)
     full_score = arc_score + lab_score
     parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr)
     # other loss?
     jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
     reg_loss = self.reg_scores_loss(arc_score, lab_score)
     #
     info["loss_parse"] = BK.get_value(parsing_loss).item()
     final_loss = parsing_loss
     if jpos_loss is not None:
         info["loss_jpos"] = BK.get_value(jpos_loss).item()
         final_loss = parsing_loss + jpos_loss
     if reg_loss is not None:
         final_loss = final_loss + reg_loss
     info["fb"] = 1
     if training:
         BK.backward(final_loss, loss_factor)
     return info
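Note: all of the fb_on_batch variants on this page share one skeleton: score the batch, sum whichever loss terms are available (parsing, joint-POS, score regularization), record diagnostics in info, and backpropagate only when training. Below is a minimal standalone sketch of that skeleton, assuming BK wraps PyTorch (so BK.backward(loss, f) behaves like (loss * f).backward()); the model and the loss here are hypothetical stand-ins, not the project's actual scorer.

import torch
import torch.nn as nn

def fb_on_batch_sketch(model, feats, training=True, loss_factor=1.0):
    scores = model(feats)                          # forward scoring
    loss = -scores.log_softmax(-1).mean()          # stand-in for the parsing loss
    info = {"loss_parse": loss.item(), "fb": 1}
    if training:
        (loss * loss_factor).backward()            # ~ BK.backward(final_loss, loss_factor)
    return info

info = fb_on_batch_sketch(nn.Linear(8, 8), torch.randn(2, 5, 8))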
Code Example #2
File: s2p.py Project: ValentinaPy/zmsp
 def fb_on_batch(self,
                 annotated_insts: List[ParseInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     # todo(note): here always using training lambdas
     full_score, original_scores, jpos_pack, mask_expr, valid_mask_d, _ = \
         self._score(annotated_insts, False, self.lambda_g1_arc_training, self.lambda_g1_lab_training)
     parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr,
                                     valid_mask_d)
     # other loss?
     jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
     reg_loss = self.reg_scores_loss(*original_scores)
     #
     info["loss_parse"] = BK.get_value(parsing_loss).item()
     final_loss = parsing_loss
     if jpos_loss is not None:
         info["loss_jpos"] = BK.get_value(jpos_loss).item()
         final_loss = parsing_loss + jpos_loss
     if reg_loss is not None:
         final_loss = final_loss + reg_loss
     info["fb"] = 1
     if training:
         BK.backward(final_loss, loss_factor)
     return info
Code Example #3
File: model.py Project: ValentinaPy/zmsp
 def _inference_mentions(self, insts: List[Sentence], lexi_repr, enc_repr, mask_expr, extractor: NodeExtractorBase, item_creator):
     sel_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = \
         extractor.predict(insts, lexi_repr, enc_repr, mask_expr)
     # handling outputs here: prepare new items
     head_idxes_arr = BK.get_value(sel_idxes)  # [*, max-count]
     lab_idxes_arr = BK.get_value(sel_lab_idxes)  # [*, max-count]
     logprobs_arr = BK.get_value(sel_logprobs)  # [*, max-count]
     valid_arr = BK.get_value(sel_valid_mask)  # [*, max-count]
     all_items = []
     bsize, mc = valid_arr.shape
     for one_idxes, one_valids, one_lab_idxes, one_logprobs, one_sent in \
             zip(head_idxes_arr, valid_arr, lab_idxes_arr, logprobs_arr, insts):
         sid = one_sent.sid
         partial_id0 = f"{one_sent.doc.doc_id}-s{one_sent.sid}-i"
         for this_i in range(mc):
             this_valid = float(one_valids[this_i])
             if this_valid == 0:  # must be compact
                 assert np.all(one_valids[this_i:]==0.)
                 all_items.extend([None] * (mc-this_i))
                 break
             # todo(note): we need to assign various info at the outside
             this_mention = Mention(HardSpan(sid, int(one_idxes[this_i]), None, None))
             # todo(note): where to filter None?
             this_hlidx = extractor.idx2hlidx(one_lab_idxes[this_i])
             all_items.append(item_creator(partial_id0+str(this_i), this_mention, this_hlidx, float(one_logprobs[this_i])))
     # only return the items and the ones useful for later steps: List(sent)[List(items)], *[*, max-count]
     ret_items = np.asarray(all_items, dtype=object).reshape((bsize, mc))
     return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
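Note the recurring pattern here: every tensor is pulled back to NumPy with BK.get_value before the per-sentence Python loop, so the loop never touches device tensors or autograd state. The assert np.all(one_valids[this_i:]==0.) also documents an assumption: the extractor emits left-packed ("compact") validity masks, so once one slot is invalid, everything after it is too, and the row can be padded with None in one step.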
Code Example #4
File: g2p.py Project: ValentinaPy/zmsp
 def fb_on_batch(self,
                 annotated_insts: List[ParseInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     # pruning and scores from g1
     valid_mask, go1_pack = self._get_g1_pack(annotated_insts,
                                              self.lambda_g1_arc_training,
                                              self.lambda_g1_lab_training)
     # encode
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
         annotated_insts, training)
     mask_expr = BK.input_real(mask_arr)
     # the parsing loss
     final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
     parsing_loss, parsing_scores, info = \
         self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value)
     info["loss_parse"] = BK.get_value(parsing_loss).item()
     final_loss = parsing_loss
     # other loss?
     jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
     if jpos_loss is not None:
         info["loss_jpos"] = BK.get_value(jpos_loss).item()
         final_loss = parsing_loss + jpos_loss
     if parsing_scores is not None:
         reg_loss = self.reg_scores_loss(*parsing_scores)
         if reg_loss is not None:
             final_loss = final_loss + reg_loss
     info["fb"] = 1
     if training:
         BK.backward(final_loss, loss_factor)
     return info
Code Example #5
File: score.py Project: ValentinaPy/zmsp
def main(args):
    conf, model, vpack, test_iter = prepare_test(args)
    dconf = conf.dconf
    # todo(note): here is the main change
    # make sure the model is an order-1 graph model, otherwise it cannot run through
    all_results = []
    all_insts = []
    with utils.Timer(tag="Run-score", info="", print_date=True):
        for cur_insts in test_iter:
            all_insts.extend(cur_insts)
            batched_arc_scores, batched_label_scores = model.score_on_batch(
                cur_insts)
            batched_arc_scores = BK.get_value(batched_arc_scores)
            batched_label_scores = BK.get_value(batched_label_scores)
            for cur_idx in range(len(cur_insts)):
                cur_len = len(cur_insts[cur_idx]) + 1
                # discarding paddings
                cur_res = (batched_arc_scores[cur_idx, :cur_len, :cur_len],
                           batched_label_scores[cur_idx, :cur_len, :cur_len])
                all_results.append(cur_res)
    # reorder to the original order
    orig_indexes = [z.inst_idx for z in all_insts]
    orig_results = [None] * len(orig_indexes)
    for new_idx, orig_idx in enumerate(orig_indexes):
        assert orig_results[orig_idx] is None
        orig_results[orig_idx] = all_results[new_idx]
    # saving
    with utils.Timer(tag="Run-write",
                     info=f"Writing to {dconf.output_file}",
                     print_date=True):
        import pickle
        with utils.zopen(dconf.output_file, "wb") as fd:
            for one in orig_results:
                pickle.dump(one, fd)
    utils.printing("The end.")
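Since the loop above streams one pickle.dump per instance into a single file, reading the scores back means calling pickle.load repeatedly on the same handle until EOF. A small sketch of such a loader (using plain open; the original writes through utils.zopen, which may add compression):

import pickle

def load_all_results(path):
    results = []
    with open(path, "rb") as fd:
        while True:
            try:
                results.append(pickle.load(fd))  # one (arc, label) score pair per instance
            except EOFError:
                break
    return results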
Code Example #6
File: model.py Project: ValentinaPy/zmsp
 def _inference_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                     evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt):
     arg_linker = self.arg_linker
     repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
     repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
     role_logprobs, role_predictions = arg_linker.predict(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                                                          ef_valid_mask, evt_valid_mask)
     # add them in place
     roles_arr = BK.get_value(role_predictions)  # [*, len-ef, len-evt]
     logprobs_arr = BK.get_value(role_logprobs)
     for bidx, one_roles_arr in enumerate(roles_arr):
         one_ef_items, one_evt_items = ef_items[bidx], evt_items[bidx]
         # =====
         # todo(note): delete origin links!
         for z in one_ef_items:
             if z is not None:
                 z.links.clear()
         for z in one_evt_items:
             if z is not None:
                 z.links.clear()
         # =====
         one_logprobs = logprobs_arr[bidx]
         for ef_idx, one_ef in enumerate(one_ef_items):
             if one_ef is None:
                 continue
             for evt_idx, one_evt in enumerate(one_evt_items):
                 if one_evt is None:
                     continue
                 one_role_idx = int(one_roles_arr[ef_idx, evt_idx])
                 if one_role_idx > 0:  # link
                     this_hlidx = arg_linker.idx2hlidx(one_role_idx)
                     one_evt.add_arg(one_ef, role=str(this_hlidx), role_idx=this_hlidx,
                                     score=float(one_logprobs[ef_idx, evt_idx]))
Code Example #7
 def lookup(self, insts: List, input_lexi, input_expr, input_mask):
     bsize = len(insts)
     # get gold or pre-set ones, again [*, slen, L] -> [*, mc]
     gold_masks, _, gold_items_arr, gold_valid = self.batch_inputs_g0(insts)
     sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = self._pmask2idxes(
         gold_masks)
     ret_items = gold_items_arr[np.arange(bsize)[:, np.newaxis],
                                BK.get_value(sel_idxes),
                                BK.get_value(sel_lab_idxes)]
     return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
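The gold_items_arr[np.arange(bsize)[:, np.newaxis], ...] expression is ordinary NumPy advanced indexing: the column-shaped batch range broadcasts against the [bsize, mc] index arrays, picking one element per (batch, slot) pair. A self-contained example with made-up sizes:

import numpy as np

bsize, slen, n_lab, mc = 2, 4, 3, 2
items_arr = np.arange(bsize * slen * n_lab).reshape(bsize, slen, n_lab)
sel_idxes = np.array([[0, 2], [1, 3]])      # [bsize, mc] token positions
sel_lab_idxes = np.array([[1, 0], [2, 2]])  # [bsize, mc] label ids
picked = items_arr[np.arange(bsize)[:, np.newaxis], sel_idxes, sel_lab_idxes]
print(picked.shape)  # (2, 2): one item per (batch, slot)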
Code Example #8
File: head.py Project: ValentinaPy/zmsp
 def lookup(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     bsize = len(insts)
     # first get gold/input info, also multiple valid-masks
     gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(
         insts)
     # step 1: no selection, simply forward using gold_masks
     sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
     sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
     sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
     # todo(+N): only get items by head position!
     _tmp_i0 = np.arange(bsize)[:, np.newaxis]
     _tmp_i1 = BK.get_value(sel_idxes)
     sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
     sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
     # step 2: encoding and labeling
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
         ret_items = sel_items  # dim-1==0
     else:
         # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = self.hl.lookup(sel_lab_idxes)  # todo(note): here no softlookup?
         ret_items = sel_items
         # second type
         if self.use_secondary_type:
             sel2_lab_idxes = sel_gold_idxes2
             sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)  # todo(note): here no softlookup?
             sel2_valid_mask = (sel2_lab_idxes > 0).float()
             # combine the two
             if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                 ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                 sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                 sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                 sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                 sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
     # step 3: exclude nil assuming no deliberate nil in gold/inputs
     if conf.exclude_nil:  # [*, mc', ...]
         sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
             self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
     # step 4: return
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     # mask out invalid items with None
     ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None  # np.bool is removed in newer NumPy
     return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Code Example #9
File: g2p.py Project: ValentinaPy/zmsp
 def decode_one(self, slen: int, projective: bool, arr_o1_masks,
                arr_o1_scores, input_pack, cur_bidx_mask):
     m_idxes, h_idxes, _, _, final_scores = input_pack
     if arr_o1_scores is None:
         arr_o1_scores = np.full([slen, slen], 0., dtype=np.double)
     # direct add to the scores
     m_idxes, h_idxes, final_scores = \
         m_idxes[cur_bidx_mask], h_idxes[cur_bidx_mask], final_scores[cur_bidx_mask]
     arr_o1_scores[BK.get_value(m_idxes),
                   BK.get_value(h_idxes)] += BK.get_value(final_scores)
     return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, None,
                       None, None)
Code Example #10
File: modelR.py Project: ValentinaPy/zmsp
 def _pred_and_put_res(self, predictor, hidden_t, evt_arr, put_f):
     logits = predictor(hidden_t)  # [bsize, ?, Out]
     log_probs = BK.log_softmax(logits, -1)
     max_log_probs, max_label_idxes = log_probs.max(-1)  # [bs, ?], simply argmax prediction
     max_log_probs_arr = BK.get_value(max_log_probs)
     max_label_idxes_arr = BK.get_value(max_label_idxes)
     for evt_row, lprob_row, lidx_row in zip(evt_arr, max_log_probs_arr, max_label_idxes_arr):
         for one_evt, one_lprob, one_lidx in zip(evt_row, lprob_row, lidx_row):
             if one_evt is not None:
                 put_f(one_evt, one_lprob, one_lidx)  # callback for inplace setting
Code Example #11
File: base_search.py Project: ValentinaPy/zmsp
 def _new_states(self, flattened_states: List[EfState], scoring_mask_ct,
                 topk_arc_scores, topk_m, topk_h, topk_label_scores,
                 topk_label_idxes):
     topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes = \
         (BK.get_value(z) for z in (topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes))
     new_states = []
     # for each batch element
     for one_state, one_mask, one_arc_scores, one_ms, one_hs, one_label_scores, one_labels in \
             zip(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes):
         one_new_states = []
         # for each of the k arc selection
         for cur_arc_score, cur_m, cur_h, cur_label_scores, cur_labels in \
                 zip(one_arc_scores, one_ms, one_hs, one_label_scores, one_labels):
             # first need that selection to be valid
             cur_arc_score, cur_m, cur_h = cur_arc_score.item(), cur_m.item(), cur_h.item()
             if one_mask[cur_m, cur_h].item() > 0.:
                 # for each of the label
                 for this_label_score, this_label in zip(cur_label_scores, cur_labels):
                     this_label_score, this_label = this_label_score.item(), this_label.item()
                     # todo(note): actually add new state; do not include label score if label does not come from ef
                     cur_all_score = (cur_arc_score + this_label_score) if self.system_labeled else cur_arc_score
                     this_new_state = one_state.build_next(action=EfAction(cur_h, cur_m, this_label),
                                                           score=cur_all_score)
                     one_new_states.append(this_new_state)
         new_states.append(one_new_states)
     return new_states
Code Example #12
File: mtl.py Project: ValentinaPy/zmsp
 def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
     conf = self.conf
     # -----
     # special mode
     if conf.aug_word2 and conf.aug_word2_aug_encoder:
         _rop = RefreshOptions(training=False)  # special feature-mode!!
         self.embedder.refresh(_rop)
         self.encoder.refresh(_rop)
     # -----
     emb_t, mask_t = self.embedder(cur_input_map)
     rel_dist = cur_input_map.get("rel_dist", None)
     if rel_dist is not None:
         rel_dist = BK.input_idx(rel_dist)
     if conf.enc_choice == "vrec":
         enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
     elif conf.enc_choice == "original":  # todo(note): change back to arr for back compatibility
         assert rel_dist is None, "Original encoder does not support rel_dist"
         enc_t = self.encoder(emb_t, BK.get_value(mask_t))
         cache, enc_loss = None, None
     else:
         raise NotImplementedError()
     # another encoder based on attn
     final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
     if conf.aug_word2:
         emb2_t = self.aug_word2(insts)
         if conf.aug_word2_aug_encoder:
             # simply add them all together, detach orig-enc as features
             stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
             features = self.aug_mixturer(stack_hidden_t)
             aug_input = (emb2_t + conf.aug_detach_ratio*self.aug_detach_drop(features))
             final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t,
                                                             rel_dist=rel_dist, collect_loss=collect_loss)
         else:
             final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
     return emb_t, mask_t, final_enc_t, cache, enc_loss
Code Example #13
File: nmst.py Project: ValentinaPy/zmsp
def nmst_greedy(scores_expr,
                mask_expr,
                lengths_arr,
                labeled=True,
                ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and take the max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr,
                                                        dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = [
                BK.get_value(z)
                for z in (greedy_heads, greedy_labels, combine_max_scores)
            ]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
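nmst_greedy maximizes over heads and labels jointly by viewing the last two axes as one, taking a single argmax, and then recovering both indices with integer division and modulo. The same trick in bare PyTorch (hypothetical shapes; BK is assumed to wrap torch):

import torch

scores = torch.randn(2, 5, 5, 3)     # [bs, mod, head, label]
bs, slen, _, n_lab = scores.shape
flat = scores.reshape(bs, slen, -1)  # merge (head, label) into one axis
best_scores, flat_idx = flat.max(-1)
heads = flat_idx // n_lab            # back to the head index
labels = flat_idx % n_lab            # back to the label index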
Code Example #14
 def jpos_decode(self, insts: List[ParseInstance], jpos_pack):
     # jpos prediction (directly index, no converting)
     jpos_preds_expr = jpos_pack[2]
     if jpos_preds_expr is not None:
         jpos_preds_arr = BK.get_value(jpos_preds_expr)
         for one_idx, one_inst in enumerate(insts):
             cur_length = len(one_inst) + 1  # including the artificial ROOT
             one_inst.pred_poses.build_vals(
                 jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)
Code Example #15
File: model_expand.py Project: ValentinaPy/zmsp
 def predict(self, ms_items: List, bert_expr):
     conf = self.conf
     bsize = len(ms_items)
     # collect instances
     col_efs, col_sents, col_bidxes_t, col_hidxes_t, _, _ = self._collect_insts(
         ms_items, False)
     if len(col_efs) == 0:
         return
     left_scores, right_scores = self._score(bert_expr, col_bidxes_t,
                                             col_hidxes_t)
     if conf.use_binary_scorer:
         lscores_arr, rscores_arr = BK.get_value(left_scores), BK.get_value(right_scores)
         #
         for one_ef, one_sent, one_lscores, one_rscores in zip(
                 col_efs, col_sents, lscores_arr, rscores_arr):
             one_ldist = self._binary_decide_dist(one_lscores)
             one_rdist = self._binary_decide_dist(one_rscores)
             # set span
             hspan = one_ef.mention.hard_span
             sid, head_wid = hspan.sid, hspan.head_wid
             left_wid = max(1, head_wid - one_ldist)  # not the artificial root
             right_wid = min(one_sent.length - 1, head_wid + one_rdist)
             hspan.wid = left_wid
             hspan.length = right_wid - left_wid + 1
     else:
         # simply pick max
         _, left_max_dist = left_scores.max(-1)
         _, right_max_dist = right_scores.max(-1)
         lmax_arr, rmax_arr = BK.get_value(left_max_dist), BK.get_value(right_max_dist)
         #
         for one_ef, one_sent, one_ldist, one_rdist in zip(
                 col_efs, col_sents, lmax_arr, rmax_arr):
             one_ldist, one_rdist = int(one_ldist), int(one_rdist)
             # set span
             hspan = one_ef.mention.hard_span
             sid, head_wid = hspan.sid, hspan.head_wid
             left_wid = max(1, head_wid - one_ldist)  # not the artificial root
             right_wid = min(one_sent.length - 1, head_wid + one_rdist)
             hspan.wid = left_wid
             hspan.length = right_wid - left_wid + 1
Code Example #16
 def collect_pruning_info(insts: List[ParseInstance], valid_mask_f):
     # two dimensions: coverage and pruning-effect
     maxlen = BK.get_shape(valid_mask_f, -1)
     # 1. coverage
     valid_mask_f_flattened = valid_mask_f.view([-1, maxlen])  # [bs*len, len]
     cur_mod_base = 0
     all_mods, all_heads = [], []
     for cur_idx, cur_inst in enumerate(insts):
         for m, h in enumerate(cur_inst.heads.vals[1:], 1):
             all_mods.append(m + cur_mod_base)
             all_heads.append(h)
         cur_mod_base += maxlen
     cov_count = len(all_mods)
     cov_valid = BK.get_value(valid_mask_f_flattened[all_mods, all_heads].sum()).item()
     # 2. pruning-rate
     # todo(warn): to speed up, these stats are approximate because of including paddings
     # edges
     pr_edges = int(np.prod(BK.get_shape(valid_mask_f)))
     pr_edges_valid = BK.get_value(valid_mask_f.sum()).item()
     # valid as structured heads
     pr_o2_sib = pr_o2_g = pr_edges
     pr_o3_gsib = maxlen * pr_edges
     valid_chs_counts, valid_par_counts = valid_mask_f.sum(-2), valid_mask_f.sum(-1)  # [*, len]
     valid_gsibs = valid_chs_counts * valid_par_counts
     pr_o2_sib_valid = BK.get_value(valid_chs_counts.sum()).item()
     pr_o2_g_valid = BK.get_value(valid_par_counts.sum()).item()
     pr_o3_gsib_valid = BK.get_value(valid_gsibs.sum()).item()
     return {
         "cov_count": cov_count,
         "cov_valid": cov_valid,
         "pr_edges": pr_edges,
         "pr_edges_valid": pr_edges_valid,
         "pr_o2_sib": pr_o2_sib,
         "pr_o2_g": pr_o2_g,
         "pr_o3_gsib": pr_o3_gsib,
         "pr_o2_sib_valid": pr_o2_sib_valid,
         "pr_o2_g_valid": pr_o2_g_valid,
         "pr_o3_gsib_valid": pr_o3_gsib_valid
     }
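In short: cov_valid counts how many gold (mod, head) edges survive pruning (coverage), while each pr_*/pr_*_valid pair measures how much the first-order pruning mask shrinks the candidate space for the corresponding higher-order part (plain edges, sibling, grandparent, grand-sibling). As the todo(warn) says, these rates are approximate because padding positions are counted too.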
Code Example #17
File: g2p.py Project: ValentinaPy/zmsp
 def decode_one(self, slen: int, projective: bool, arr_o1_masks,
                arr_o1_scores, input_pack, cur_bidx_mask):
     m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores = input_pack
     o3gsib_pack = [
         m_idxes[cur_bidx_mask].int(), h_idxes[cur_bidx_mask].int(),
         sib_idxes[cur_bidx_mask].int(), gp_idxes[cur_bidx_mask].int(),
         final_scores[cur_bidx_mask].double()
     ]
     o3gsib_arr_pack = [BK.get_value(z) for z in o3gsib_pack]
     return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, None,
                       None, o3gsib_arr_pack)
Code Example #18
File: nmst.py Project: ValentinaPy/zmsp
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr,
                                                 lengths_arr,
                                                 labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax,
                                                mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
Code Example #19
File: nmst.py Project: ValentinaPy/zmsp
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr,
                                                lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(-1) * \
            BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
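The labeled marginals come from redistributing each arc's unlabeled mass over its labels: since scores_unlabeled = logsumexp over the label axis, the factor exp(scores_expr - scores_unlabeled) is exactly the per-arc softmax over labels, so marginal(m, h, l) = marginal(m, h) * softmax_l(scores(m, h, :))[l].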
Code Example #20
File: efp.py Project: ValentinaPy/zmsp
 def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
     self.refresh_batch(training)
     # encode
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
     mask_expr = BK.input_real(mask_arr)
     # g1 score
     g1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training, self.lambda_g1_lab_training)
     # the parsing loss
     parsing_loss, parsing_scores, info = self.losser.loss(annotated_insts, enc_repr, mask_arr, g1_pack)
     # whether add jpos loss?
     jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
     #
     no_loss = True
     final_loss = 0.
     if parsing_loss is None:
         info["loss_parse"] = 0.
     else:
         final_loss = final_loss + parsing_loss
         info["loss_parse"] = BK.get_value(parsing_loss).item()
         no_loss = False
     if jpos_loss is None:
         info["loss_jpos"] = 0.
     else:
         final_loss = final_loss + jpos_loss
         info["loss_jpos"] = BK.get_value(jpos_loss).item()
         no_loss = False
     if parsing_scores is not None:
         arc_scores, lab_scores = parsing_scores
         reg_loss = self.reg_scores_loss(arc_scores, lab_scores)
         if reg_loss is not None:
             final_loss = final_loss + reg_loss
     info["fb"] = 1
     if training and not no_loss:
         info["fb_back"] = 1
         BK.backward(final_loss, loss_factor)
     return info
Code Example #21
 def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
     conf = self.conf
     # score
     scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
     _, argmax_idxes = scores_t.max(-1)  # [bs, ?+rlen]
     argmax_idxes_arr = BK.get_value(argmax_idxes)  # [bs, ?+rlen]
     # assign; todo(+2): record scores?
     one_offset = int(self.add_root_token)
     for one_bidx, one_inst in enumerate(insts):
         one_pidxes = argmax_idxes_arr[one_bidx, one_offset:one_offset + len(one_inst)].tolist()
         one_pseq = SeqField(None)
         one_pseq.build_vals(one_pidxes, self.vocab)
         one_inst.add_item("pred_" + self.attr_name, one_pseq, assert_non_exist=False)
     return
Code Example #22
File: g2p.py Project: ValentinaPy/zmsp
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     # iconf = self.conf.iconf
     with BK.no_grad_env():
         self.refresh_batch(False)
         # pruning and scores from g1
         valid_mask, go1_pack = self._get_g1_pack(
             insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         # decode
         final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
         ret_heads, ret_labels, _, _ = self.dl.decode(
             insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
         # collect the results together
         all_heads = Helper.join_list(ret_heads)
         if ret_labels is None:
             # todo(note): simply get labels from the go1-label classifier; must provide g1parser
             if go1_pack is None:
                 _, go1_pack = self._get_g1_pack(insts, 1., 1.)
             _, go1_label_max_idxes = go1_pack[1].max(-1)  # [bs, slen, slen]
             pred_heads_arr, _ = self.predict_padder.pad(all_heads)  # [bs, slen]
             pred_heads_expr = BK.input_idx(pred_heads_arr)
             pred_labels_expr = BK.gather_one_lastdim(go1_label_max_idxes, pred_heads_expr).squeeze(-1)
             all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
         else:
             all_labels = np.concatenate(ret_labels, 0)
         # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
         for one_idx, one_inst in enumerate(insts):
             cur_length = len(one_inst) + 1
             one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab)
             # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
         # =====
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Code Example #23
File: seqcrf.py Project: ValentinaPy/zmsp
 def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
     conf = self.conf
     if self.add_root_token:
         repr_t = repr_t[:, 1:]
         mask_t = mask_t[:, 1:]
     # score
     scores_t = self._score(repr_t)  # [bs, rlen, D]
     # decode
     _, decode_idx = self._viterbi_decode(scores_t, mask_t.bool())
     decode_idx_arr = BK.get_value(decode_idx)  # [bs, rlen]
     for one_bidx, one_inst in enumerate(insts):
         one_pidxes = decode_idx_arr[one_bidx].tolist()[:len(one_inst)]
         one_pseq = SeqField(None)
         one_pseq.build_vals(one_pidxes, self.vocab)
         one_inst.add_item("pred_" + self.attr_name,
                           one_pseq,
                           assert_non_exist=False)
     return
Code Example #24
File: head.py Project: ValentinaPy/zmsp
 def _exclude_nil(self,
                  sel_idxes,
                  sel_valid_mask,
                  sel_lab_idxes,
                  sel_lab_embeds,
                  sel_logprobs=None,
                  sel_items_arr=None):
     # todo(note): assure that nil is 0
     sel_valid_mask = sel_valid_mask * (sel_lab_idxes != 0).float()  # not inplaced
     # idx on idx
     s2_idxes, s2_valid_mask = BK.mask2idx(sel_valid_mask)
     sel_idxes = sel_idxes.gather(-1, s2_idxes)
     sel_valid_mask = s2_valid_mask
     sel_lab_idxes = sel_lab_idxes.gather(-1, s2_idxes)
     sel_lab_embeds = BK.gather_first_dims(sel_lab_embeds, s2_idxes, -2)
     sel_logprobs = None if sel_logprobs is None else sel_logprobs.gather(-1, s2_idxes)
     sel_items_arr = None if sel_items_arr is None \
         else sel_items_arr[np.arange(len(sel_items_arr))[:, np.newaxis], BK.get_value(s2_idxes)]
     return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs, sel_items_arr
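BK.mask2idx, used by _exclude_nil and by the lookup/loss methods above, is assumed to left-pack a 0/1 mask into position indexes plus a validity mask, padding empty slots with index 0. A minimal sketch of those assumed semantics in plain PyTorch:

import torch

def mask2idx_sketch(mask):
    # mask: [bs, slen] of 0./1. -> (idxes, valid), both [bs, max_count]
    mc = max(int(mask.sum(-1).max().item()), 1)
    idxes = torch.zeros(mask.size(0), mc, dtype=torch.long)
    valid = torch.zeros(mask.size(0), mc)
    for b in range(mask.size(0)):
        nz = mask[b].nonzero(as_tuple=True)[0]  # positions of the 1s, left-packed
        idxes[b, :len(nz)] = nz
        valid[b, :len(nz)] = 1.
    return idxes, valid

idxes, valid = mask2idx_sketch(torch.tensor([[1., 0., 1.], [0., 0., 0.]]))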
Code Example #25
 def _decode(self, insts: List[ParseInstance], full_score, mask_expr,
             misc_prefix):
     # decode
     mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
     mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
     mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(
         full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
     if self.conf.iconf.output_marginals:
         # todo(note): here, we care about marginals for arc
         # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True)
         arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1)
         bsize, max_len = BK.get_shape(mask_expr)
         idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1)
         idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0)
         output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr,
                                     BK.input_idx(mst_heads_arr)]
         mst_marg_arr = BK.get_value(output_marg)
     else:
         mst_marg_arr = None
     # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
     for one_idx, one_inst in enumerate(insts):
         cur_length = mst_lengths[one_idx]
         one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
         one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
         one_scores = mst_scores_arr[one_idx][:cur_length]
         one_inst.pred_par_scores.set_vals(one_scores)
         # extra output
         one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
         if mst_marg_arr is not None:
             one_inst.extra_pred_misc[misc_prefix + "_marg"] = mst_marg_arr[one_idx][:cur_length].tolist()
Code Example #26
File: mtl.py Project: ValentinaPy/zmsp
 def _assign_attns_item(self, insts, prefix, input_erase_mask_arr=None, abs_posi_arr=None, cache=None):
     if cache is not None:
         attn_names, attn_list = [], []
         for one_sidx, one_attn in enumerate(cache.list_attn):
             attn_names.append(f"{prefix}_att{one_sidx}")
             attn_list.append(one_attn)
         if cache.accu_attn is not None:
             attn_names.append(f"{prefix}_att_accu")
             attn_list.append(cache.accu_attn)
         for one_name, one_attn in zip(attn_names, attn_list):
             # (step_idx, ) -> [bs, len_q, len_k, head]
             one_attn_arr = BK.get_value(one_attn)
             for bidx, inst in enumerate(insts):
                 save_arr = one_attn_arr[bidx]
                 inst.add_item(one_name, NpArrField(save_arr, float_decimal=4), assert_non_exist=False)
     if abs_posi_arr is not None:
         for bidx, inst in enumerate(insts):
             inst.add_item(f"{prefix}_abs_posi",
                           NpArrField(abs_posi_arr[bidx], float_decimal=0), assert_non_exist=False)
     if input_erase_mask_arr is not None:
         for bidx, inst in enumerate(insts):
             inst.add_item(f"{prefix}_erase_mask",
                           NpArrField(input_erase_mask_arr[bidx], float_decimal=4), assert_non_exist=False)
Code Example #27
File: g2p.py Project: ValentinaPy/zmsp
 def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr,
          go1_pack, training: bool, margin: float):
     # first do decoding and related preparation
     with BK.no_grad_env():
         _, _, g_packs, p_packs = self.decode(insts, enc_expr,
                                              final_valid_expr, go1_pack,
                                              training, margin)
         # flatten the packs (remember to rebase the indexes)
         gold_pack = self._flatten_packs(g_packs)
         pred_pack = self._flatten_packs(p_packs)
         if self.filter_pruned:
             # filter out non-valid (pruned) edges, to avoid prune error
             mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask(final_valid_expr, gold_pack)
             pred_mask = mod_unpruned_mask[pred_pack[0], pred_pack[1]]  # filter by specific mod
             gold_pack = [(None if z is None else z[gold_mask])
                          for z in gold_pack]
             pred_pack = [(None if z is None else z[pred_mask])
                          for z in pred_pack]
     # calculate the scores for loss
     gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack
     pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack
     gold_arc_score, gold_label_score_all = self._get_basic_score(
         enc_expr, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes,
         gold_gp_idxes)
     pred_arc_score, pred_label_score_all = self._get_basic_score(
         enc_expr, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes,
         pred_gp_idxes)
     # whether have labeled scores
     if self.system_labeled:
         gold_label_score = BK.gather_one_lastdim(
             gold_label_score_all, gold_lab_idxes).squeeze(-1)
         pred_label_score = BK.gather_one_lastdim(
             pred_label_score_all, pred_lab_idxes).squeeze(-1)
         ret_scores = (gold_arc_score, pred_arc_score, gold_label_score,
                       pred_label_score)
         pred_full_scores, gold_full_scores = pred_arc_score + pred_label_score, gold_arc_score + gold_label_score
     else:
         ret_scores = (gold_arc_score, pred_arc_score)
         pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score
     # hinge loss: filter-margin by loss*margin to be aware of search error
     if self.filter_margin:
         with BK.no_grad_env():
             mat_shape = BK.get_shape(enc_expr)[:2]  # [bs, slen]
             heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                            gold_b_idxes, gold_m_idxes,
                                            gold_h_idxes)
             heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                            pred_b_idxes, pred_m_idxes,
                                            pred_h_idxes)
             error_count = (heads_gold != heads_pred).float()
             if self.system_labeled:
                 labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                                 gold_b_idxes, gold_m_idxes,
                                                 gold_lab_idxes)
                 labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                                 pred_b_idxes, pred_m_idxes,
                                                 pred_lab_idxes)
                 error_count += (labels_gold != labels_pred).float()
             scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32,
                                             gold_b_idxes, gold_m_idxes,
                                             gold_full_scores)
             scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32,
                                             pred_b_idxes, pred_m_idxes,
                                             pred_full_scores)
             # todo(note): here, a small 0.1 is to exclude zero error: anyway they will get zero gradient
             sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <=
                          (margin * error_count.sum(-1) + 0.1)).float()
             num_valid_sent = float(BK.get_value(sent_mask.sum()))
         final_loss_sum = (
             pred_full_scores * sent_mask[pred_b_idxes] -
             gold_full_scores * sent_mask[gold_b_idxes]).sum()
     else:
         num_valid_sent = len(insts)
         final_loss_sum = (pred_full_scores - gold_full_scores).sum()
     # prepare final loss
     # divide loss by what?
     num_sent = len(insts)
     num_valid_tok = sum(len(z) for z in insts)
     if self.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "sent_valid": num_valid_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     return final_loss, ret_scores, info
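The filter_margin branch makes the hinge loss aware of search errors, as the inline comment says: a sentence's mask stays 1 only when its gold score does not exceed the (margin-augmented) predicted score by more than margin times its number of arc/label errors, so sentences the model already gets right by a comfortable margin drop out of the loss. Per the todo(note), the small 0.1 slack handles zero-error sentences, which would only contribute zero gradients anyway.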
Code Example #28
File: head.py Project: ValentinaPy/zmsp
 def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
     conf = self.conf
     bsize = len(insts)
     # first get gold info, also multiple valid-masks
     gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(
         insts)
     input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
     # step 1: selector
     if conf.use_selector:
         sel_loss, sel_mask = self.sel.loss(input_expr,
                                            input_mask,
                                            gold_masks,
                                            margin=margin)
     else:
         sel_loss = None
         sel_mask = self._select_cands_training(input_mask, gold_masks, conf.train_min_rate)
     sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
     sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
     sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
     # todo(+N): only get items by head position!
     _tmp_i0 = np.arange(bsize)[:, np.newaxis]
     _tmp_i1 = BK.get_value(sel_idxes)
     sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
     sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
     # step 2: encoding and labeling
     # if we select nothing
     # ----- debug
     # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
     # -----
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         lab_loss = [[BK.zeros([]), BK.zeros([])]]
         sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
         ret_items = sel_items  # dim-1==0
     else:
         sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask,
                                   sel_idxes)  # [*, mc, DLab]
         lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(
             sel_hid_exprs, sel_valid_mask, sel_gold_idxes, margin=margin)
         if conf.train_gold_corr:
             sel_lab_idxes = sel_gold_idxes
             if not self.hl.conf.use_lookup_soft:
                 sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
         ret_items = sel_items
         # =====
         if self.use_secondary_type:
             sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
             if conf.sectype_noback_enc:
                 sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
             else:
                 sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
             # =====
             # special for the sectype mask (sample it within the gold ones)
             sel2_valid_mask = self._select_cands_training(
                 (sel_gold_idxes > 0).float(),
                 (sel_gold_idxes2 > 0).float(), conf.train_min_rate_s2)
             # =====
             sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(
                 sel2_input,
                 sel2_valid_mask,
                 sel_gold_idxes2,
                 margin=margin)
             if conf.train_gold_corr:
                 sel2_lab_idxes = sel_gold_idxes2
                 if not self.hl.conf.use_lookup_soft:
                     sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
             if conf.sectype_t2ift1:
                 sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
             # combine the two
             if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                 ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                 sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                 sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                 sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                 sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
         else:
             sel2_lab_loss = None
         # =====
         # step 3: exclude nil and return
         if conf.exclude_nil:  # [*, mc', ...]
             sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
                 self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     # step 4: finally prepare loss and items
     for one_loss in lab_loss:
         one_loss[0] *= conf.lambda_ne
     ret_losses = lab_loss
     if sel2_lab_loss is not None:
         for one_loss in sel2_lab_loss:
             one_loss[0] *= conf.lambda_ne2
         ret_losses = ret_losses + sel2_lab_loss
     if sel_loss is not None:
         for one_loss in sel_loss:
             one_loss[0] *= conf.lambda_ns
         ret_losses = ret_losses + sel_loss
     # ----- debug
     # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
     # -----
     # mask out invalid items with None
     ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None  # np.bool is removed in newer NumPy
     return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Code Example #29
 def fb(self, annotated_insts, scoring_expr_pack, training: bool,
        loss_factor: float):
     # depth constrain: <= sched_depth
     cur_depth_constrain = int(self.sched_depth.value)
     # run
     ags = [
         BfsLinearAgenda.init_agenda(TdState, z, self.require_sg)
         for z in annotated_insts
     ]
     self.oracle_manager.refresh_insts(annotated_insts)
     self.searcher.refresh(scoring_expr_pack)
     self.searcher.go(ags)
     # collect local loss: credit assignment
     if self.train_force or self.train_ss:
         states = []
         for ag in ags:
             for final_state in ag.local_golds:
                 # todo(warn): remember to use depth_eff rather than depth
                 # todo(warn): deprecated
                 # if final_state.depth_eff > cur_depth_constrain:
                 #     continue
                 states.append(final_state)
         logprobs_arc = [s.arc_score_slice for s in states]
         # no labeling scores for reduce operations
         logprobs_label = [
             s.label_score_slice for s in states
             if s.label_score_slice is not None
         ]
         credits_arc, credits_label = None, None
     elif self.train_of:
         states = []
         for ag in ags:
             for final_state in ag.ends:
                 for s in final_state.get_path(True):
                     states.append(s)
         logprobs_arc = [s.arc_score_slice for s in states]
         # no labeling scores for reduce operations
         logprobs_label = [
             s.label_score_slice for s in states
             if s.label_score_slice is not None
         ]
         credits_arc, credits_label = None, None
     elif self.train_rl:
         logprobs_arc, logprobs_label, credits_arc, credits_label = [], [], [], []
         for ag in ags:
             # todo(+2): need to check search failure?
             # todo(+2): ignoring labels when reducing or wrong-arc
             for final_state in ag.ends:
                 # todo(warn): deprecated
                 # if final_state.depth_eff > cur_depth_constrain:
                 #     continue
                 one_credits_arc = []
                 one_credits_label = []
                 self.oracle_manager.set_losses(final_state)
                 for s in final_state.get_path(True):
                     _, _, delta_arc, delta_label = s.oracle_loss_cache
                     logprobs_arc.append(s.arc_score_slice)
                     if delta_arc > 0:
                         # only blame arc
                         one_credits_arc.append(-delta_arc)
                     else:
                         one_credits_arc.append(0)
                         if delta_label > 0:
                             logprobs_label.append(s.label_score_slice)
                             one_credits_label.append(-delta_label)
                         elif s.label_score_slice is not None:
                             # not bad labeling
                             logprobs_label.append(s.label_score_slice)
                             one_credits_label.append(0)
                 # TODO(+N): minus average may encourage bad moves?
                 # balance
                 # avg_arc = sum(one_credits_arc) / len(one_credits_arc)
                 # avg_label = 0. if len(one_credits_label)==0 else sum(one_credits_label) / len(one_credits_label)
                 baseline_arc = baseline_label = -0.5
                 credits_arc.extend(z - baseline_arc
                                    for z in one_credits_arc)
                 credits_label.extend(z - baseline_label
                                      for z in one_credits_label)
     else:
         raise NotImplementedError("CANNOT get here!")
     # sum all local losses
     loss_zero = BK.zeros([])
     if len(logprobs_arc) > 0:
         batched_logprobs_arc = SliceManager.combine_slices(
             logprobs_arc, None)
         loss_arc = (-BK.sum(batched_logprobs_arc)) if (credits_arc is None) \
             else (-BK.sum(batched_logprobs_arc * BK.input_real(credits_arc)))
     else:
         loss_arc = loss_zero
     if len(logprobs_label) > 0:
         batched_logprobs_label = SliceManager.combine_slices(
             logprobs_label, None)
         loss_label = (-BK.sum(batched_logprobs_label)) if (credits_label is None) \
             else (-BK.sum(batched_logprobs_label*BK.input_real(credits_label)))
     else:
         loss_label = loss_zero
     final_loss_sum = loss_arc + loss_label
     # divide loss by what?
     num_sent = len(annotated_insts)
     num_valid_arcs, num_valid_labels = len(logprobs_arc), len(logprobs_label)
     # num_valid_steps = len(states)
     if self.tconf.loss_div_step:
         final_loss = loss_arc / max(1, num_valid_arcs) + loss_label / max(1, num_valid_labels)
     else:
         final_loss = final_loss_sum / num_sent
     #
     val_loss_arc = BK.get_value(loss_arc).item()
     val_loss_label = BK.get_value(loss_label).item()
     val_loss_sum = val_loss_arc + val_loss_label
     #
     cur_has_loss = 1 if ((num_valid_arcs + num_valid_labels) > 0) else 0
     if training and cur_has_loss:
         BK.backward(final_loss, loss_factor)
     # todo(warn): make tok==steps for dividing in common.run
     info = {
         "sent": num_sent,
         "tok": num_valid_arcs,
         "valid_arc": num_valid_arcs,
         "valid_label": num_valid_labels,
         "loss_sum": val_loss_sum,
         "loss_arc": val_loss_arc,
         "loss_label": val_loss_label,
         "fb_all": 1,
         "fb_valid": cur_has_loss
     }
     return info
Code Example #30
File: g2p.py Project: ValentinaPy/zmsp
 def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr,
             mb_valid_expr, mb_go1_pack, training: bool, margin: float):
     # =====
     use_sib, use_gp = self.use_sib, self.use_gp
     # =====
     mb_size = len(mb_insts)
     mat_shape = BK.get_shape(mb_valid_expr)
     max_slen = mat_shape[-1]
     # step 1: extract the candidate features
     batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = self.helper.get_cand_features(
         mb_valid_expr)
     # =====
     # step 2: high order scoring
     # step 2.1: basic scoring, [*], [*, Lab]
     arc_scores, lab_scores = self._get_basic_score(mb_enc_expr,
                                                    batch_idxes, m_idxes,
                                                    h_idxes, sib_idxes,
                                                    gp_idxes)
     cur_system_labeled = (lab_scores is not None)
     # step 2.2: margin
     # get gold labels, which can be useful for later calculating loss
     if training:
         gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = \
             [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_insts(mb_insts, use_sib, use_gp)]
         # add the margins to the scores: (m,h), (m,sib), (m,gp)
         cur_margin = margin / self.margin_div
         self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes,
                                   gold_h_idxes, gold_lab_idxes,
                                   batch_idxes, m_idxes, h_idxes,
                                   arc_scores, lab_scores, cur_margin,
                                   cur_margin)
         if use_sib:
             self._add_margin_inplaced(mat_shape, gold_b_idxes,
                                       gold_m_idxes, gold_sib_idxes,
                                       gold_lab_idxes, batch_idxes, m_idxes,
                                       sib_idxes, arc_scores, lab_scores,
                                       cur_margin, cur_margin)
         if use_gp:
             self._add_margin_inplaced(mat_shape, gold_b_idxes,
                                       gold_m_idxes, gold_gp_idxes,
                                       gold_lab_idxes, batch_idxes, m_idxes,
                                       gp_idxes, arc_scores, lab_scores,
                                       cur_margin, cur_margin)
         # may be useful for later training
         gold_pack = (mb_size, gold_b_idxes, gold_m_idxes, gold_h_idxes,
                      gold_sib_idxes, gold_gp_idxes, gold_lab_idxes)
     else:
         gold_pack = None
     # step 2.3: o1scores
     if mb_go1_pack is not None:
         go1_arc_scores, go1_lab_scores = mb_go1_pack
         # todo(note): go1_arc_scores is not added here, but as the input to the dec-algo
         if cur_system_labeled:
             lab_scores += go1_lab_scores[batch_idxes, m_idxes, h_idxes]
     else:
         go1_arc_scores = None
     # step 2.4: max out labels; todo(+N): or using logsumexp here?
     if cur_system_labeled:
         max_lab_scores, max_lab_idxes = lab_scores.max(-1)
         final_scores = arc_scores + max_lab_scores  # [*], final input arc scores
     else:
         max_lab_idxes = None
         final_scores = arc_scores
     # =====
     # step 3: actual decode
     res_heads = []
     for sid, inst in enumerate(mb_insts):
         slen = len(inst) + 1  # plus one for the art-root
         arr_o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int())
         arr_o1_scores = BK.get_value(go1_arc_scores[sid, :slen, :slen].double()) \
             if (go1_arc_scores is not None) else None
         cur_bidx_mask = (batch_idxes == sid)
         input_pack = [m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores]
         one_heads = self.helper.decode_one(slen, self.projective,
                                            arr_o1_masks, arr_o1_scores,
                                            input_pack, cur_bidx_mask)
         res_heads.append(one_heads)
     # =====
     # step 4: get labels back and pred_pack
     pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, _ = \
         [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_preds(res_heads, None, use_sib, use_gp)]
     if cur_system_labeled:
         # obtain hit components
         pred_hit_mask = self._get_hit_mask(mat_shape, pred_b_idxes,
                                            pred_m_idxes, pred_h_idxes,
                                            batch_idxes, m_idxes, h_idxes)
         if use_sib:
             pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes,
                                                 pred_m_idxes,
                                                 pred_sib_idxes,
                                                 batch_idxes, m_idxes,
                                                 sib_idxes)
         if use_gp:
             pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes,
                                                 pred_m_idxes,
                                                 pred_gp_idxes, batch_idxes,
                                                 m_idxes, gp_idxes)
         # get pred labels (there should be only one hit per mod!)
         pred_labels = BK.constants_idx([mb_size, max_slen], 0)
         pred_labels[batch_idxes[pred_hit_mask],
                     m_idxes[pred_hit_mask]] = max_lab_idxes[pred_hit_mask]
         res_labels = BK.get_value(pred_labels)
         pred_lab_idxes = pred_labels[pred_b_idxes, pred_m_idxes]
     else:
         res_labels = None
         pred_lab_idxes = None
     pred_pack = (mb_size, pred_b_idxes, pred_m_idxes, pred_h_idxes,
                  pred_sib_idxes, pred_gp_idxes, pred_lab_idxes)
     # return
     return res_heads, res_labels, gold_pack, pred_pack