Example #1
 def forward_features(self, ids_expr, mask_expr, typeids_expr,
                      other_embed_exprs: List):
     bmodel = self.model
     bmodel_embedding = bmodel.embeddings
     bmodel_encoder = bmodel.encoder
     # prepare
     attention_mask = mask_expr
     token_type_ids = BK.zeros(BK.get_shape(
         ids_expr)).long() if typeids_expr is None else typeids_expr
     extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
     # extended_attention_mask = extended_attention_mask.to(dtype=next(bmodel.parameters()).dtype)  # fp16 compatibility
     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
     # embeddings
     cur_layer = 0
     if self.trainable_min_layer <= 0:
         last_output = bmodel_embedding(ids_expr,
                                        position_ids=None,
                                        token_type_ids=token_type_ids)
     else:
         with BK.no_grad_env():
             last_output = bmodel_embedding(ids_expr,
                                            position_ids=None,
                                            token_type_ids=token_type_ids)
     # extra embeddings (this implies overall gradient requirements!!)
     for one_eidx, one_embed in enumerate(self.other_embeds):
         last_output += one_embed(
             other_embed_exprs[one_eidx])  # [bs, slen, D]
     # =====
     all_outputs = []
     if self.layer_is_output[cur_layer]:
         all_outputs.append(last_output)
     cur_layer += 1
     # todo(note): be careful about the indexes!
     # not-trainable encoders
     trainable_min_layer_idx = max(0, self.trainable_min_layer - 1)
     with BK.no_grad_env():
         for layer_module in bmodel_encoder.layer[:trainable_min_layer_idx]:
             last_output = layer_module(last_output,
                                        extended_attention_mask, None)[0]
             if self.layer_is_output[cur_layer]:
                 all_outputs.append(last_output)
             cur_layer += 1
     # trainable encoders
     for layer_module in bmodel_encoder.layer[trainable_min_layer_idx:self.output_max_layer]:
         last_output = layer_module(last_output, extended_attention_mask,
                                    None)[0]
         if self.layer_is_output[cur_layer]:
             all_outputs.append(last_output)
         cur_layer += 1
     assert cur_layer == self.output_max_layer + 1
     # stack
     if len(all_outputs) == 1:
         ret_expr = all_outputs[0].unsqueeze(-2)
     else:
         ret_expr = BK.stack(all_outputs, -2)  # [BS, SLEN, LAYER, D]
     final_ret_exp = self.output_f(ret_expr)
     return final_ret_exp
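
Two patterns recur in the snippet above: a 0/1 padding mask is turned into an additive attention bias, and the layers below the trainable range are run inside the no-grad context so they stay frozen. The following is a minimal standalone sketch of both in plain PyTorch; `extend_attention_mask` and `toy_layers` are illustrative names, not part of the library.

import torch
import torch.nn as nn

def extend_attention_mask(mask: torch.Tensor) -> torch.Tensor:
    # [bs, slen] -> [bs, 1, 1, slen]; padded positions get a large negative additive bias
    ext = mask[:, None, None, :].to(torch.float32)
    return (1.0 - ext) * -10000.0

# run the non-trainable lower layers without gradient tracking, the upper ones normally
toy_layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # stand-ins for encoder layers
trainable_min_layer_idx = 2
hidden = torch.randn(3, 5, 8)  # [bs, slen, D]
with torch.no_grad():
    for layer in toy_layers[:trainable_min_layer_idx]:
        hidden = layer(hidden)
for layer in toy_layers[trainable_min_layer_idx:]:
    hidden = layer(hidden)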
Example #2
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     iconf = self.conf.iconf
     pconf = iconf.pruning_conf
     with BK.no_grad_env():
         self.refresh_batch(False)
         if iconf.use_pruning:
             # todo(note): for the testing of pruning mode, use the scores instead
             if self.g1_use_aux_scores:
                 valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune(
                     insts, self.num_label, pconf)
             else:
                 valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch(
                     insts, pconf)
             valid_mask_f = valid_mask.float()  # [*, len, len]
             mask_value = Constants.REAL_PRAC_MIN
             full_score = arc_score.unsqueeze(-1) + label_score
             full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1)
             info_pruning = G1Parser.collect_pruning_info(
                 insts, valid_mask_f)
             jpos_pack = [None, None, None]
         else:
             input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
                 insts, False)
             mask_expr = BK.input_real(mask_arr)
             full_score = self.scorer_helper.score_full(enc_repr)
             info_pruning = None
         # =====
         self._decode(insts, full_score, mask_expr, "g1")
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         if info_pruning is not None:
             info.update(info_pruning)
         return info
Example #3
 def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.conf = conf
     # vocab and padder
     self.word_vocab = vpack.get_voc("word")
     self.padder = DataPadder(
         2, pad_vals=self.word_vocab.pad,
         mask_range=2)  # todo(note): <pad>-id is very large
     # models
     self.hid_layer = self.add_sub_node(
         "hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
     self.pred_layer = self.add_sub_node(
         "pred",
         Affine(pc,
                conf.hid_dim,
                conf.max_pred_rank + 1,
                init_rop=NoDropRop()))
     if conf.init_pred_from_pretrain:
         npvec = vpack.get_emb("word")
         if npvec is None:
             zwarn(
                 "Pretrained vector not provided, skip init pred embeddings!!"
             )
         else:
             with BK.no_grad_env():
                 self.pred_layer.ws[0].copy_(
                     BK.input_real(npvec[:conf.max_pred_rank + 1].T))
             zlog(
                 f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1})."
             )
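
The `init_pred_from_pretrain` branch copies rows of a pretrained embedding table into the prediction layer's weight under no-grad. A hedged sketch of the same idea with a plain `nn.Linear` follows; the sizes and the `npvec` array are made up, and whether a transpose (the `.T` above) is needed depends on how the layer stores its weight.

import numpy as np
import torch
import torch.nn as nn

emb_dim, max_pred_rank = 300, 999
npvec = np.random.randn(50000, emb_dim).astype(np.float32)  # stand-in for vpack.get_emb("word")

pred_layer = nn.Linear(emb_dim, max_pred_rank + 1)
with torch.no_grad():
    # nn.Linear keeps weight as [out, in], so the first rows of the table copy in directly
    pred_layer.weight.copy_(torch.from_numpy(npvec[:max_pred_rank + 1]))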
Example #4
 def inference_on_batch(self, insts: List[DocInstance], **kwargs):
     self.refresh_batch(False)
     # -----
     if len(insts) == 0:
         return {}
     # -----
     # todo(note): first do shallow copy!
     for one_doc in insts:
         for one_sent in one_doc.sents:
             one_sent.pred_entity_fillers = [
                 z for z in one_sent.entity_fillers
             ]
             one_sent.pred_events = [
                 shallow_copy(z) for z in one_sent.events
             ]
     # -----
     ndoc, nsent = len(insts), 0
     iconf = self.conf.iconf
     with BK.no_grad_env():
         # splitting into buckets
         all_packs = self.bter.run(insts, training=False)
         for one_pack in all_packs:
             ms_items, bert_expr, basic_expr = one_pack
             nsent += len(ms_items)
             self.predictor.predict(ms_items, bert_expr, basic_expr)
     info = {
         "doc": ndoc,
         "sent": nsent,
         "num_evt": sum(len(z.pred_events) for z in insts)
     }
     if iconf.decode_verbose:
         zlog(f"Decode one mini-batch: {info}")
     return info
Example #5
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # encode
         enc_expr, mask_expr = self.enc.run(
             insts, False,
             input_word_mask_repl=None)  # no masklm in testing
         # decode
         self.dec.predict(insts, enc_expr, mask_expr)
         # =====
         # test for masklm
         input_word_mask_repl_arr, output_pred_mask_repl_arr, ouput_pred_idx_arr = self.masklm.prepare(
             insts, False)
         enc_expr2, mask_expr2 = self.enc.run(
             insts, False, input_word_mask_repl=input_word_mask_repl_arr)
         masklm_loss = self.masklm.loss(enc_expr2,
                                        output_pred_mask_repl_arr,
                                        ouput_pred_idx_arr)
         masklm_loss_val, masklm_loss_count, masklm_corr_count = [
             z.item() for z in masklm_loss[0]
         ]
         # =====
     info = {
         "sent": len(insts),
         "tok": sum(map(len, insts)),
         "masklm_loss_val": masklm_loss_val,
         "masklm_loss_count": masklm_loss_count,
         "masklm_corr_count": masklm_corr_count
     }
     return info
Example #6
def nmst_greedy(scores_expr,
                mask_expr,
                lengths_arr,
                labeled=True,
                ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(
            BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr,
                                                        dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = [
                BK.get_value(z)
                for z in (greedy_heads, greedy_labels, combine_max_scores)
            ]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
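
The greedy decoder above picks head and label jointly by flattening the last two score dimensions, taking an argmax, and recovering the two indices with integer division and modulo. A small self-contained PyTorch sketch of that arithmetic (shapes and the -1e9 diagonal constant are assumptions):

import torch

bs, slen, nlab = 2, 5, 10
scores = torch.randn(bs, slen, slen, nlab)  # [bs, mod, head, label]

# mask out self-loops: a large negative value on the (mod == head) diagonal
scores = scores + torch.diag(torch.full((slen,), -1e9)).unsqueeze(-1)

# jointly pick (head, label) per modifier by flattening the last two dims
flat = scores.reshape(bs, slen, slen * nlab)
best_scores, best_idx = flat.max(dim=-1)
greedy_heads = torch.div(best_idx, nlab, rounding_mode="floor")  # recover head index
greedy_labels = best_idx % nlab                                  # recover label index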
Example #7
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     _LEQ_MAP = lambda x, m: [max(m, z) for z in x]  # larger or equal map
     #
     with BK.no_grad_env():
         self.refresh_batch(False)
         scoring_prep_f = lambda ones: self._prepare_score(ones, False)[0]
         get_best_state_f = lambda ag: sorted(ag.ends, key=lambda s: s.score_accu, reverse=True)[0]
         ags, info = self.inferencer.decode(insts, scoring_prep_f, False)
         # put the results inplaced
         for one_inst, one_ag in zip(insts, ags):
             best_state = get_best_state_f(one_ag)
             one_inst.pred_heads.set_vals(_LEQ_MAP(best_state.list_arc, 0))  # directly int-val for heads
             # todo(warn): already the correct labels, no need to transform
             # -- one_inst.pred_labels.build_vals(self.pred2real_labels(best_state.list_label), self.label_vocab)
             one_inst.pred_labels.build_vals(_LEQ_MAP(best_state.list_label, 1), self.label_vocab)
             # todo(warn): add children ordering in MISC field
             one_inst.pred_miscs.set_vals(self._get_chs_ordering(best_state))
         # check search err
         if self.conf.iconf.check_serr:
             ags_fo, _ = self.fber.force_decode(insts, scoring_prep_f, False)
             serr_sent = 0
             serr_tok = 0
             for ag, ag_fo in zip(ags, ags_fo):
                 best_state = get_best_state_f(ag)
                 best_fo_state = get_best_state_f(ag_fo)
                 # for this one, only care about UAS
                 if best_state.score_accu < best_fo_state.score_accu:
                     cur_serr_tok = sum((1 if a!=b else 0) for a,b in zip(best_state.list_arc[1:], best_fo_state.list_arc[1:]))
                     if cur_serr_tok > 0:
                         serr_sent += 1
                         serr_tok += cur_serr_tok
             info["serr_sent"] = serr_sent
             info["serr_tok"] = serr_tok
         return info
Example #8
 def score_on_batch(self, insts: List[ParseInstance]):
     with BK.no_grad_env():
         self.refresh_batch(False)
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         # mask_expr = BK.input_real(mask_arr)
         arc_score = self.scorer_helper.score_arc(enc_repr)
         label_score = self.scorer_helper.score_label(enc_repr)
         return arc_score.squeeze(-1), label_score
Example #9
 def score_on_batch(self, insts: List[ParseInstance]):
     with BK.no_grad_env():
         self.refresh_batch(False)
         scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
             insts, False)
         full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                               False, 0.)
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, False, 0.)
         return full_arc_score, full_label_score
Example #10
 def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
     conf = self.conf
     self.refresh_batch(False)
     # print(f"{len(insts)}: {insts[0].sid}")
     with BK.no_grad_env():
         # decode for dpar
         input_map = self.model.inputter(insts)
         emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False)
         input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
         self.dpar.predict(insts, BK.zeros([1,1]), input_t, mask_t)
     return {}
Example #11
 def score_and_select(self, enc_expr, pad_mask):
     conf = self.conf
     if conf.ns_no_back:
         enc_expr = enc_expr.detach()
     all_scores = self.scorer(enc_expr).squeeze(-1)  # [*, slen]
     # only for getting the mask
     with BK.no_grad_env():
         masked_all_scores = all_scores + (
             1. - pad_mask) * Constants.REAL_PRAC_MIN
         res_mask = self._select_topk(masked_all_scores, pad_mask, pad_mask,
                                      conf.topk_ratio, conf.thresh_k)
     return res_mask, all_scores
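
The selection step relies on adding a very negative constant to padded positions so that top-k can never pick them. A minimal sketch of that masking plus a fixed-k selection in plain PyTorch (the real `_select_topk` also handles a ratio and a threshold, which are omitted here):

import torch

NEG_INF = -1e9  # stand-in for Constants.REAL_PRAC_MIN
all_scores = torch.randn(2, 7)  # [*, slen]
pad_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1, 0]], dtype=torch.float32)

# push padded positions to the bottom so they can never be selected
masked_all_scores = all_scores + (1.0 - pad_mask) * NEG_INF
topk_vals, topk_idxes = masked_all_scores.topk(k=3, dim=-1)
res_mask = torch.zeros_like(pad_mask).scatter_(-1, topk_idxes, 1.0)  # 0/1 selection mask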
Example #12
 def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
     conf = self.conf
     self.refresh_batch(False)
     with BK.no_grad_env():
         # special mode
         # use: CUDA_VISIBLE_DEVICES=3 PYTHONPATH=../../src/ python3 -m pdb ../../src/tasks/cmd.py zmlm.main.test ${RUN_DIR}/_conf device:0 dict_dir:${RUN_DIR}/ model_load_name:${RUN_DIR}/zmodel.best test:./_en.debug test_interactive:1
         if conf.test_interactive:
             iinput_sent = input(">> (Interactive testing) Input sent sep by blanks: ")
             iinput_tokens = iinput_sent.split()
             if len(iinput_sent) > 0:
                 iinput_inst = GeneralSentence.create(iinput_tokens)
                 iinput_inst.word_seq.set_idxes([self.word_vocab.get_else_unk(w) for w in iinput_inst.word_seq.vals])
                 iinput_inst.char_seq.build_idxes(self.inputter.vpack.get_voc("char"))
                 iinput_map = self.inputter([iinput_inst])
                 iinput_erase_mask = np.asarray([[z=="Z" for z in iinput_tokens]]).astype(dtype=np.float32)
                 iinput_masked_map = self.inputter.mask_input(iinput_map, iinput_erase_mask, set("pos"))
                 emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(iinput_masked_map, collect_loss=False, insts=[iinput_inst])
                 mlm_loss = self.masklm.loss(enc_t, iinput_erase_mask, iinput_map)
                 dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
                 self.dpar.predict([iinput_inst], enc_t, dpar_input_attn, mask_t)
                 self.upos.predict([iinput_inst], enc_t, mask_t)
                 # print them
                 import pandas as pd
                 cur_fields = {
                     "idxes": list(range(1, len(iinput_inst)+1)),
                     "word": iinput_inst.word_seq.vals, "pos": iinput_inst.pred_pos_seq.vals,
                     "head": iinput_inst.pred_dep_tree.heads[1:], "dlab": iinput_inst.pred_dep_tree.labels[1:]}
                 zlog(f"Result:\n{pd.DataFrame(cur_fields).to_string()}")
             return {}  # simply return here for interactive mode
         # -----
         # test for MLM simply as in training (use special separate rand_gen to keep the masks the same for testing)
         # todo(+2): do we need to keep testing/validating during training the same? Currently not!
         info = self.fb_on_batch(insts, training=False, rand_gen=self.testing_rand_gen, assign_attns=conf.testing_get_attns)
         # -----
         if len(insts) == 0:
             return info
         # decode for dpar
         input_map = self.inputter(insts)
         emb_t, mask_t, enc_t, cache, _ = self._emb_and_enc(input_map, collect_loss=False, insts=insts)
         dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
         self.dpar.predict(insts, enc_t, dpar_input_attn, mask_t)
         self.upos.predict(insts, enc_t, mask_t)
         if self.ner is not None:
             self.ner.predict(insts, enc_t, mask_t)
         # -----
         if conf.testing_get_attns:
             if conf.enc_choice == "vrec":
                 self._assign_attns_item(insts, "orig", cache=cache)
             elif conf.enc_choice in ["original"]:
                 pass
             else:
                 raise NotImplementedError()
         return info
Example #13
 def inference_on_batch(self, insts: List[DocInstance], **kwargs):
     self.refresh_batch(False)
     # -----
     if len(insts) == 0:
         return {}
     # -----
     ndoc, nsent = len(insts), 0
     iconf = self.conf.iconf
     # =====
     # get tmp ms_items for each event
     input_ms_items = self._insts2msitems(insts)
     # -----
     if len(input_ms_items) == 0:
         return {}
     # -----
     with BK.no_grad_env():
         # splitting into buckets
         all_packs = self.bter.run(input_ms_items, training=False)
         for one_pack in all_packs:
             ms_items, bert_expr, basic_expr = one_pack
             nsent += len(ms_items)
             # cands
             if iconf.lookup_ef:
                 self._lookup_efs(ms_items)
             else:
                 self.cand_extractor.predict(ms_items, bert_expr,
                                             basic_expr)
             # args
             if iconf.pred_arg:
                 self.arg_linker.predict(ms_items, bert_expr, basic_expr)
             # span
             if iconf.pred_span:
                 self.span_expander.predict(ms_items, bert_expr)
     # put back all predictions
     self._putback_preds(input_ms_items)
     # collect all stats
     num_ef, num_evt, num_arg = 0, 0, 0
     for one_doc in insts:
         for one_sent in one_doc.sents:
             num_ef += len(one_sent.pred_entity_fillers)
             num_evt += len(one_sent.pred_events)
             num_arg += sum(len(z.links) for z in one_sent.pred_events)
     info = {
         "doc": ndoc,
         "sent": nsent,
         "num_ef": num_ef,
         "num_evt": num_evt,
         "num_arg": num_arg
     }
     if iconf.decode_verbose:
         zlog(f"Decode one mini-batch: {info}")
     return info
Example #14
 def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1.,
                 rand_gen=None, assign_attns=False, **kwargs):
     self.refresh_batch(training)
     # get inputs with models
     with BK.no_grad_env():
         input_map = self.model.inputter(insts)
         emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True)
         input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
     losses = [self.dpar.loss(insts, BK.zeros([1,1]), input_t, mask_t)]
     # -----
     info = self.collect_loss_and_backward(losses, training, loss_factor)
     info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
     return info
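
Here the encoder features (the concatenated attention maps) are computed under no-grad, so only the downstream `dpar` module receives gradients. The pattern of a frozen feature extractor feeding a trainable probe looks roughly like this in plain PyTorch (both modules are illustrative stand-ins):

import torch
import torch.nn as nn

frozen_encoder = nn.Linear(16, 16)  # stand-in for the pretrained model
probe = nn.Linear(16, 3)            # stand-in for the trainable parser head

x = torch.randn(4, 9, 16)
with torch.no_grad():               # features carry no gradient back into the encoder
    feats = frozen_encoder(x)
loss = probe(feats).sum()
loss.backward()                     # only the probe's parameters get gradients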
Example #15
 def prune_on_batch(self, insts: List[ParseInstance], pconf: PruneG1Conf):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         arc_score = self.scorer_helper.score_arc(enc_repr)
         label_score = self.scorer_helper.score_label(enc_repr)
         final_valid_mask, arc_marginals = G1Parser.prune_with_scores(
             arc_score, label_score, mask_expr, pconf)
         return final_valid_mask, arc_score.squeeze(
             -1), label_score, mask_expr, arc_marginals
Example #16
 def loss(self, enc_expr, pad_mask, gold_mask, margin: float):
     conf = self.conf
     # =====
     # first testing-mode scoring and selecting
     res_mask, all_scores = self.score_and_select(enc_expr, pad_mask)
     # add gold
     if conf.ns_add_gold:
         res_mask += gold_mask
         res_mask.clamp_(max=1.)
     # =====
     with BK.no_grad_env():
         # how to select instances for training
         if conf.train_ratio2gold > 0.:
             # use gold-ratio for training
             masked_all_scores = all_scores + (
                 1. - pad_mask + gold_mask) * Constants.REAL_PRAC_MIN
             loss_mask = self._select_topk(masked_all_scores, pad_mask,
                                           gold_mask, conf.train_ratio2gold,
                                           None)
             loss_mask += gold_mask
             loss_mask.clamp_(max=1.)
         elif not conf.ns_add_gold:
             loss_mask = res_mask + gold_mask
             loss_mask.clamp_(max=1.)
         else:
             # we already have the gold
             loss_mask = res_mask
     # ===== calculating losses [*, L]
     # first aug scores by margin
     aug_scores = all_scores - (conf.margin_pos * margin) * gold_mask + (
         conf.margin_neg * margin) * (1. - gold_mask)
     if self.loss_hinge:
         # multiply pos instances with -1
         flipped_scores = aug_scores * (1. - 2 * gold_mask)
         losses_all = BK.clamp(flipped_scores, min=0.)
     elif self.loss_prob:
         losses_all = BK.binary_cross_entropy_with_logits(aug_scores,
                                                          gold_mask,
                                                          reduction='none')
         if conf.no_loss_satisfy_margin:
             unsatisfy_mask = ((aug_scores * (1. - 2 * gold_mask)) >
                               0.).float()  # those still with hinge loss
             losses_all *= unsatisfy_mask
     else:
         raise NotImplementedError()
     # return prediction and loss(sum/count)
     loss_sum = (losses_all * loss_mask).sum()
     if conf.train_return_loss_mask:
         return [[loss_sum, loss_mask.sum()]], loss_mask
     else:
         return [[loss_sum, loss_mask.sum()]], res_mask
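
The hinge branch augments scores by a margin and then flips the sign at gold positions, so that any remaining violation becomes a positive loss after clamping. A compact sketch, assuming a single margin coefficient instead of the separate `margin_pos` / `margin_neg` above:

import torch

def margin_hinge(all_scores: torch.Tensor, gold_mask: torch.Tensor, margin: float) -> torch.Tensor:
    # all_scores, gold_mask: [*, slen]; gold_mask is 1. at gold positions
    aug_scores = all_scores - margin * gold_mask + margin * (1.0 - gold_mask)
    flipped_scores = aug_scores * (1.0 - 2.0 * gold_mask)  # gold positions get multiplied by -1
    return torch.clamp(flipped_scores, min=0.0)            # per-position hinge losses

losses_all = margin_hinge(torch.randn(2, 7), (torch.rand(2, 7) > 0.8).float(), margin=1.0)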
Example #17
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
         # g1 score
         g1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # decode for parsing
         self.inferencer.decode(insts, enc_repr, mask_arr, g1_pack, self.label_vocab)
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Example #18
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     # iconf = self.conf.iconf
     with BK.no_grad_env():
         self.refresh_batch(False)
         full_score, _, jpos_pack, mask_expr, _, _ = \
             self._score(insts, False, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # collect the results together
         # =====
         self._decode(insts, full_score, mask_expr, "s2")
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Example #19
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr,
                                                lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(
            -1) * BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
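
The labeled marginals are recovered by multiplying each arc's unlabeled marginal with the label posterior exp(score - logsumexp), i.e. a softmax over the label dimension. A standalone sketch of just that conversion (the projective arc-marginal routine itself is assumed to exist elsewhere):

import torch

def labeled_from_unlabeled_marginals(scores: torch.Tensor,
                                     unlabeled_marginals: torch.Tensor) -> torch.Tensor:
    # scores: [BS, m, h, L]; unlabeled_marginals: [BS, m, h]
    scores_unlabeled = torch.logsumexp(scores, dim=-1)               # [BS, m, h]
    label_cond = torch.exp(scores - scores_unlabeled.unsqueeze(-1))  # softmax over the label dim
    return unlabeled_marginals.unsqueeze(-1) * label_cond            # [BS, m, h, L]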
Example #20
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     # iconf = self.conf.iconf
     with BK.no_grad_env():
         self.refresh_batch(False)
         # pruning and scores from g1
         valid_mask, go1_pack = self._get_g1_pack(
             insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         # decode
         final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
         ret_heads, ret_labels, _, _ = self.dl.decode(
             insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
         # collect the results together
         all_heads = Helper.join_list(ret_heads)
         if ret_labels is None:
             # todo(note): simply get labels from the go1-label classifier; must provide g1parser
             if go1_pack is None:
                 _, go1_pack = self._get_g1_pack(insts, 1., 1.)
             _, go1_label_max_idxes = go1_pack[1].max(
                 -1)  # [bs, slen, slen]
             pred_heads_arr, _ = self.predict_padder.pad(
                 all_heads)  # [bs, slen]
             pred_heads_expr = BK.input_idx(pred_heads_arr)
             pred_labels_expr = BK.gather_one_lastdim(
                 go1_label_max_idxes, pred_heads_expr).squeeze(-1)
             all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
         else:
             all_labels = np.concatenate(ret_labels, 0)
         # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
         for one_idx, one_inst in enumerate(insts):
             cur_length = len(one_inst) + 1
             one_inst.pred_heads.set_vals(
                 all_heads[one_idx]
                 [:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(
                 all_labels[one_idx][:cur_length], self.label_vocab)
             # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
         # =====
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Example #21
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr,
                                                 lengths_arr,
                                                 labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax,
                                                mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(
                mst_scores_arr)
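
After an (unlabeled) MST decoder returns one head per modifier, the label is read off by gathering the pre-computed argmax label at that head, which is what the `gather_one_lastdim` call does. A plain PyTorch sketch of that gather, with a random stand-in for the CPU MST output:

import torch

bs, slen, nlab = 2, 6, 8
scores = torch.randn(bs, slen, slen, nlab)            # [BS, mod, head, label]
scores_unlabeled_max, labels_argmax = scores.max(-1)  # best label per (mod, head) pair

mst_heads = torch.randint(0, slen, (bs, slen))        # pretend output of a CPU MST routine, [BS, mod]

# gather along the last dim: the argmax label of each modifier's chosen head
mst_labels = torch.gather(labels_argmax, -1, mst_heads.unsqueeze(-1)).squeeze(-1)  # [BS, mod]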
Example #22
 def inference_on_batch(self, insts: List[DocInstance], **kwargs):
     self.refresh_batch(False)
     test_constrain_evt_types = self.test_constrain_evt_types
     ndoc, nsent = len(insts), 0
     iconf = self.conf.iconf
     with BK.no_grad_env():
         # splitting into buckets
         all_packs = self.bter.run(insts, training=False)
         for one_pack in all_packs:
             # =====
             # predict
             sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
             nsent += len(sent_insts)
             mask_expr = BK.input_real(mask_arr)
             # entity and filler
             if iconf.lookup_ef:
                 ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                     self._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, ret_copy=True)
             elif iconf.pred_ef:
                 ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                     self._inference_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, self.ef_creator)
             else:
                 ef_items = [[] for _ in range(len(sent_insts))]
                 ef_valid_mask = BK.zeros((len(sent_insts), 0))
                 ef_widxes = ef_lab_idxes = ef_lab_embeds = None
             # event
             if iconf.lookup_evt:
                 evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                     self._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, ret_copy=True)
             else:
                 evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                     self._inference_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, self.evt_creator)
             # arg
             if iconf.pred_arg:
                 # todo(note): for this step of decoding, we only consider inner-sentence pairs
                 # todo(note): inplaced
                 self._inference_args(ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                                      evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt)
             # =====
             # assign
             for one_sent_inst, one_ef_items, one_ef_valid, one_evt_items, one_evt_valid in \
                     zip(sent_insts, ef_items, BK.get_value(ef_valid_mask), evt_items, BK.get_value(evt_valid_mask)):
                 # entity and filler
                 one_ef_items = [z for z,va in zip(one_ef_items, one_ef_valid) if (va and z is not None)]
                 one_sent_inst.pred_entity_fillers = one_ef_items
                 # event
                 one_evt_items = [z for z,va in zip(one_evt_items, one_evt_valid) if (va and z is not None)]
                 if test_constrain_evt_types is not None:
                     one_evt_items = [z for z in one_evt_items if z.type in test_constrain_evt_types]
                 # =====
                 # todo(note): special rule (actually a simple rule based extender)
                 if iconf.expand_evt_compound:
                     for one_evt in one_evt_items:
                         one_hard_span = one_evt.mention.hard_span
                         sid, hwid, _ = one_hard_span.position(True)
                         assert one_hard_span.length == 1  # currently no way to predict more
                         if hwid+1 < one_sent_inst.length:
                             if one_sent_inst.uposes.vals[hwid]=="VERB" and one_sent_inst.ud_heads.vals[hwid+1]==hwid \
                                     and one_sent_inst.ud_labels.vals[hwid+1]=="compound":
                                 one_hard_span.length += 1
                 # =====
                 one_sent_inst.pred_events = one_evt_items
     return {"doc": ndoc, "sent": nsent}
Example #23
 def fb_on_batch(self,
                 annotated_insts: List[DocInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     assert self.train_constrain_evt_types is None, "Not implemented for training constrain_types"
     ndoc, nsent = len(annotated_insts), 0
     lambda_cand, lambda_arg, lambda_span = self.lambda_cand.value, self.lambda_arg.value, self.lambda_span.value
     # simply multi-task training, no explicit interactions between them
     # -----
     # only include the ones with enough annotations
     valid_insts = []
     for one_inst in annotated_insts:
         if all(z.events is not None and z.entity_fillers is not None
                for z in one_inst.sents):
             valid_insts.append(one_inst)
     input_ms_items = self._insts2msitems(valid_insts)
     if len(input_ms_items) == 0:
         return {}
     # -----
     all_packs = self.bter.run(input_ms_items, training=training)
     all_cand_losses, all_arg_losses, all_span_losses = [], [], []
     # =====
     mix_pred_ef_rate = self.conf.mix_pred_ef_rate
     mix_pred_ef_rate_outside = self.conf.mix_pred_ef_rate_outside
     mix_pred_ef_count = 0
     # cur_margin = self.margin.value  # todo(+N): currently not used!
     for one_pack in all_packs:
         ms_items, bert_expr, basic_expr = one_pack
         nsent += len(ms_items)
         if lambda_cand > 0.:
             # todo(note): no need to clean up pred ones since sentences and containers are copied
             cand_losses = self.cand_extractor.loss(ms_items, bert_expr,
                                                    basic_expr)
             all_cand_losses.append(cand_losses)
             # predict as candidates
             with BK.no_grad_env():
                 self.cand_extractor.predict(ms_items, bert_expr,
                                             basic_expr)
                 # mix into gold ones; no need to cleanup since these are copies by _insts2msitems
                 for one_msent in ms_items:
                     center_idx = one_msent.center_idx
                     for one_sidx, one_sent in enumerate(one_msent.sents):
                         hit_posi = set()
                         for one_ef in one_sent.entity_fillers:
                             posi = one_ef.mention.hard_span.position()
                             hit_posi.add(posi)
                         # add predicted ones
                         cur_mix_rate = mix_pred_ef_rate if (
                             center_idx
                             == one_sidx) else mix_pred_ef_rate_outside
                         for one_ef in one_sent.pred_entity_fillers:
                             posi = one_ef.mention.hard_span.position()
                             if posi not in hit_posi and next(
                                     self.random_sample_stream
                             ) <= cur_mix_rate:
                                 hit_posi.add(posi)
                                 # one_ef.is_mix = True
                                 # todo(note): these are not TRUE efs, but only mixing preds as neg examples for training
                                 one_sent.entity_fillers.append(one_ef)
                                 mix_pred_ef_count += 1
         if lambda_arg > 0.:
             # todo(note): since currently we are predicting all candidates for one event
             arg_losses = self.arg_linker.loss(ms_items,
                                               bert_expr,
                                               basic_expr,
                                               dynamic_prepare=True)
             all_arg_losses.append(arg_losses)
         if lambda_span > 0.:
             span_losses = self.span_expander.loss(ms_items, bert_expr)
             all_span_losses.append(span_losses)
     # =====
     # final loss sum and backward
     info = {
         "doc": ndoc,
         "sent": nsent,
         "fb": 1,
         "mix_pef": mix_pred_ef_count
     }
     if len(all_packs) > 0:
         self.collect_loss_and_backward(
             ["cand", "arg", "span"],
             [all_cand_losses, all_arg_losses, all_span_losses],
             [lambda_cand, lambda_arg, lambda_span], info, training,
             loss_factor)
     return info
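
The mixing step above deduplicates predicted mentions against gold ones by their span position and then samples them in as extra (negative) training candidates. Stripped of the library objects, the bookkeeping is roughly the following; the span tuples and the rate are invented for illustration:

import random

gold_spans = [(3, 1), (7, 2)]          # (start, length) of gold mentions
pred_spans = [(3, 1), (5, 1), (9, 2)]  # predicted candidates
cur_mix_rate = 0.5

hit_posi = set(gold_spans)
mixed = list(gold_spans)
for posi in pred_spans:
    if posi not in hit_posi and random.random() <= cur_mix_rate:
        hit_posi.add(posi)
        mixed.append(posi)             # not a true mention, only a sampled negative for training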
Example #24
 def __init__(self, pc: BK.ParamCollection, bconf: Berter2Conf):
     super().__init__(pc, None, None)
     self.bconf = bconf
     self.model_name = bconf.bert2_model
     zlog(
         f"Loading pre-trained bert model for Berter2 of {self.model_name}")
     # Load pretrained model/tokenizer
     self.tokenizer = BertTokenizer.from_pretrained(
         self.model_name,
         do_lower_case=bconf.bert2_lower_case,
         cache_dir=None if
         (not bconf.bert2_cache_dir) else bconf.bert2_cache_dir)
     self.model = BertModel.from_pretrained(
         self.model_name,
         output_hidden_states=True,
         cache_dir=None if
         (not bconf.bert2_cache_dir) else bconf.bert2_cache_dir)
     zlog(f"Load done, move to default device {BK.DEFAULT_DEVICE}")
     BK.to_device(self.model)
     # =====
     # zero padding embeddings?
     if bconf.bert2_zero_pademb:
         with BK.no_grad_env():
             # todo(warn): specific!!
             zlog(
                 f"Unusual operation: make bert's padding embedding (idx0) zero!!"
             )
             self.model.embeddings.word_embeddings.weight[0].fill_(0.)
     # =====
     # check trainable ones and add parameters
     # todo(+N): this part is specific and looks into the lib internals, can break in future versions!!
     # the idx of layer is [1(embed)] + [N(enc)], that is, layer0 is the output of embeddings
     self.hidden_size = self.model.config.hidden_size
     self.num_bert_layers = len(
         self.model.encoder.layer) + 1  # +1 for embeddings
     self.output_layers = [
         i if i >= 0 else (self.num_bert_layers + i)
         for i in bconf.bert2_output_layers
     ]
     self.layer_is_output = [False] * self.num_bert_layers
     for i in self.output_layers:
         self.layer_is_output[i] = True
     # the highest used layer
     self.output_max_layer = max(
         self.output_layers) if len(self.output_layers) > 0 else -1
     # from max-layer down
     self.trainable_layers = list(range(self.output_max_layer, -1,
                                        -1))[:bconf.bert2_trainable_layers]
     # the lowest trainable layer
     self.trainable_min_layer = min(self.trainable_layers) if len(
         self.trainable_layers) > 0 else (self.output_max_layer + 1)
     zlog(f"Build Berter2: {self}")
     # add parameters
     prefix_name = self.pc.nnc_name(self.name, True) + "/"
     for layer_idx in self.trainable_layers:
         if layer_idx == 0:  # add the embedding layer
             infix_name = "embed"
             named_params = self.pc.param_add_external(
                 prefix_name + infix_name, self.model.embeddings)
         else:
             # here we should use the original (-1) index
             infix_name = "enc" + str(layer_idx)
             named_params = self.pc.param_add_external(
                 prefix_name + infix_name,
                 self.model.encoder.layer[layer_idx - 1])
         # add to self.params
         for one_name, one_param in named_params:
             assert f"{infix_name}_{one_name}" not in self.params
             self.params[f"{infix_name}_{one_name}"] = one_param
     # for dropout/mask input
     self.random_sample_stream = Random.stream(Random.random_sample)
     # =====
     # for other inputs; todo(note): still, 0 means all-zero embedding
     self.other_embeds = [
         self.add_sub_node(
             "OE", Embedding(self.pc,
                             vsize,
                             self.hidden_size,
                             fix_row0=True))
         for vsize in bconf.bert2_other_input_vsizes
     ]
     # =====
     # for output
     if bconf.bert2_output_mode == "layered":
         self.output_f = lambda x: x
         self.output_dims = (
             self.hidden_size,
             len(self.output_layers),
         )
     elif bconf.bert2_output_mode == "concat":
         self.output_f = lambda x: x.view(BK.get_shape(x)[:-2] + [-1])  # combine the last two dims
         self.output_dims = (self.hidden_size * len(self.output_layers), )
     elif bconf.bert2_output_mode == "weighted":
         self.output_f = self.add_sub_node(
             "wb", BertFeaturesWeightLayer(pc, len(self.output_layers)))
         self.output_dims = (self.hidden_size, )
     else:
         raise NotImplementedError(
             f"UNK mode for bert2 output: {bconf.bert2_output_mode}")
Example #25
 def fb_on_batch(self,
                 annotated_insts: List[DocInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     assert self.train_constrain_evt_types is None, "Not implemented for training constrain_types"
     # -----
     if len(annotated_insts) == 0:
         return {}
     # -----
     ndoc, nsent = len(annotated_insts), 0
     lambda_mention_ef, lambda_mention_evt, lambda_arg, lambda_span = \
         self.lambda_mention_ef.value, self.lambda_mention_evt.value, self.lambda_arg.value, self.lambda_span.value
     # simply multi-task training, no explicit interactions between them
     all_packs = self.bter.run(annotated_insts, training=training)
     all_ef_mention_losses = []
     all_evt_mention_losses = []
     all_arg_losses = []
     all_span_losses = []
     #
     mix_pred_ef_rate = self.conf.mix_pred_ef_rate
     mix_pred_ef = self.conf.mix_pred_ef
     mix_pred_ef_count = 0
     # =====
     cur_margin = self.margin.value
     for one_pack in all_packs:
         ms_items, bert_expr, basic_expr = one_pack
         nsent += len(ms_items)
         if lambda_mention_ef > 0.:
             if mix_pred_ef:
                 # clear previous added fake ones
                 for one_msent in ms_items:
                     center_sent = one_msent.sents[one_msent.center_idx]
                     center_sent.entity_fillers = [
                         z for z in center_sent.entity_fillers
                         if not hasattr(z, "is_mix")
                     ]
                 # -----
             ef_losses = self.ef_extractor.loss(ms_items, bert_expr,
                                                basic_expr)
             all_ef_mention_losses.append(ef_losses)
         if lambda_mention_evt > 0.:
             evt_losses = self.evt_extractor.loss(ms_items,
                                                  bert_expr,
                                                  basic_expr,
                                                  margin=cur_margin)
             all_evt_mention_losses.append(evt_losses)
         if lambda_span > 0.:
             span_losses = self.span_expander.loss(ms_items, bert_expr)
             all_span_losses.append(span_losses)
         # predict efs as candidates for args
         if mix_pred_ef:
             with BK.no_grad_env():
                 self.ef_extractor.predict(ms_items, bert_expr, basic_expr)
                 # mix into gold ones
                 for one_msent in ms_items:
                     center_sent = one_msent.sents[one_msent.center_idx]
                     # since we might cache insts, we do not consider previous mixed ones
                     hit_posi = set()
                     center_sent.entity_fillers = [
                         z for z in center_sent.entity_fillers
                         if not hasattr(z, "is_mix")
                     ]
                     for one_ef in center_sent.entity_fillers:
                         posi = one_ef.mention.hard_span.position()
                         hit_posi.add(posi)
                     # add predicted ones
                     for one_ef in center_sent.pred_entity_fillers:
                         posi = one_ef.mention.hard_span.position()
                         if posi not in hit_posi and next(
                                 self.random_sample_stream
                         ) <= mix_pred_ef_rate:
                             hit_posi.add(posi)
                             one_ef.is_mix = True
                             # todo(note): these are not TRUE efs, but only mixing preds as neg examples for training
                             center_sent.entity_fillers.append(one_ef)
                             mix_pred_ef_count += 1
     # =====
     # in some mode, we may want to collect predicted efs
     for one_pack in all_packs:
         ms_items, bert_expr, basic_expr = one_pack
         if lambda_arg > 0.:
             arg_losses = self.arg_linker.loss(ms_items,
                                               bert_expr,
                                               basic_expr,
                                               dynamic_prepare=mix_pred_ef)
             all_arg_losses.append(arg_losses)
     # =====
     # final loss sum and backward
     info = {
         "doc": ndoc,
         "sent": nsent,
         "fb": 1,
         "mix_pef": mix_pred_ef_count
     }
     if len(all_packs) == 0:
         return info
     self.collect_loss_and_backward(["ef", "evt", "arg", "span"], [
         all_ef_mention_losses, all_evt_mention_losses, all_arg_losses,
         all_span_losses
     ], [lambda_mention_ef, lambda_mention_evt, lambda_arg, lambda_span],
                                    info, training, loss_factor)
     # =====
     # # for debug
     # zlog([d.dataset for d in annotated_insts])
     # zlog(info)
     # =====
     return info
Example #26
 def inference_on_batch(self, insts: List[DocInstance], **kwargs):
     self.refresh_batch(False)
     test_constrain_evt_types = self.test_constrain_evt_types
     # -----
     if len(insts) == 0:
         return {}
     # -----
     ndoc, nsent = len(insts), 0
     iconf = self.conf.iconf
     with BK.no_grad_env():
         # splitting into buckets
         all_packs = self.bter.run(insts, training=False)
         for one_pack in all_packs:
             ms_items, bert_expr, basic_expr = one_pack
             nsent += len(ms_items)
             # ef
             if iconf.lookup_ef:
                 self.ef_extractor.lookup(ms_items)
             elif iconf.pred_ef:
                 self.ef_extractor.predict(ms_items, bert_expr, basic_expr)
             # evt
             if iconf.lookup_evt:
                 self.evt_extractor.lookup(
                     ms_items, constrain_types=test_constrain_evt_types)
             elif iconf.pred_evt:
                 self.evt_extractor.predict(
                     ms_items,
                     bert_expr,
                     basic_expr,
                     constrain_types=test_constrain_evt_types)
         # deal with arg after pred all!!
         if iconf.pred_arg:
             for one_pack in all_packs:
                 ms_items, bert_expr, basic_expr = one_pack
                 self.arg_linker.predict(ms_items, bert_expr, basic_expr)
                 if iconf.pred_span:
                     self.span_expander.predict(ms_items, bert_expr)
     # collect all stats
     num_ef, num_evt, num_arg = 0, 0, 0
     for one_doc in insts:
         for one_sent in one_doc.sents:
             num_ef += len(one_sent.pred_entity_fillers)
             num_evt += len(one_sent.pred_events)
             num_arg += sum(len(z.links) for z in one_sent.pred_events)
             if self.static_span_expander is not None:
                 assert not iconf.pred_span, "Not compatible of these two modes!"
                 for one_ef in one_sent.pred_entity_fillers:
                     # todo(note): expand phrase by rule
                     one_hard_span = one_ef.mention.hard_span
                     head_wid = one_hard_span.head_wid
                     one_hard_span.wid, one_hard_span.length = self.static_span_expander.expand_span(
                         head_wid, one_sent)
     info = {
         "doc": ndoc,
         "sent": nsent,
         "num_ef": num_ef,
         "num_evt": num_evt,
         "num_arg": num_arg
     }
     if iconf.decode_verbose:
         zlog(f"Decode one mini-batch: {info}")
     return info
Example #27
 def loss(self, insts: List[ParseInstance], enc_repr, mask_arr, g1_pack):
     # todo(WARN): may need sg if using other loss functions
     # first-round search
     cur_margin = self.margin.value
     cur_cost0_weight = self.cost0_weight.value
     with BK.no_grad_env():
         ags = self.searcher.start(insts, self.hm_feature_getter0, enc_repr, mask_arr, g1_pack, margin=cur_margin)
         self.searcher.go(ags)
     # then forward and backward
     # collect only loss-related actions
     toks_all, sent_all = 0, len(insts)
     pieces_all, pieces_no_cost, pieces_serr, pieces_valid = 0, 0, 0, 0
     toks_valid = 0
     arc_valid_weights, label_valid_weights = 0., 0.
     action_list, arc_weight_list, label_weight_list, bidxes_list = [], [], [], []
     bidx = 0
     # =====
     score_getter = self.searcher.ender.plain_ranker
     oracler_ranker = self.searcher.ender.oracle_ranker
     # =====
     for one_inst, one_ag in zip(insts, ags):
         cur_size = len(one_inst)  # excluding ROOT
         toks_all += cur_size
         # for all the pieces
         for sp in one_ag.special_points:
             plain_finals, oracle_finals = sp.plain_finals, sp.oracle_finals
             best_plain = max(plain_finals, key=score_getter)
             best_oracle = max(plain_finals+oracle_finals, key=oracler_ranker)
             cost_plain, cost_oracle = best_plain.cost_accu, best_oracle.cost_accu
             score_plain, score_oracle = score_getter(best_plain), score_getter(best_oracle)
             # if cost_oracle > 0.:  # the gold one cannot be searched?
             #     sent_oracle_has_cost += 1
             # add them
             pieces_all += 1
             if cost_plain <= cost_oracle:
                 pieces_no_cost += 1
             elif score_plain < score_oracle:
                 pieces_serr += 1  # search error
             else:
                 pieces_valid += 1
                 plain_states, oracle_states = best_plain.get_path(), best_oracle.get_path()
                 toks_valid += len(plain_states)
                 # Loss = score(best_plain) - score(best_oracle)
                 cur_aw, cur_lw = self._add_lists(plain_states, oracle_states, action_list,
                                 arc_weight_list, label_weight_list, bidxes_list, bidx, cur_cost0_weight)
                 arc_valid_weights += cur_aw
                 label_valid_weights += cur_lw
         bidx += 1
     # collect the losses
     info = {"sent": sent_all, "tok": toks_all, "tok_valid": toks_valid, "vw_arc": arc_valid_weights, "vw_lab": label_valid_weights,
             "pieces_all": pieces_all, "pieces_no_cost": pieces_no_cost, "pieces_serr": pieces_serr, "pieces_valid": pieces_valid}
     if toks_valid == 0:
         return None, None, info
     final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores = \
         self._loss(enc_repr, action_list, arc_weight_list, label_weight_list, bidxes_list)
     # todo(+1): other indicators?
     info["loss_sum_arc"] = BK.get_value(final_arc_loss_sum).item()
     info["loss_sum_lab"] = BK.get_value(final_label_loss_sum).item()
     # how to div
     if self.loss_div_step:
         if self.loss_div_weights:
             cur_div_arc = max(arc_valid_weights, 1.)
             cur_div_lab = max(label_valid_weights, 1.)
         else:
             cur_div_arc = cur_div_lab = (toks_all if self.loss_div_fullbatch else toks_valid)
     else:
         # todo(warn): here use pieces rather than sentences
         cur_div_arc = cur_div_lab = (pieces_all if self.loss_div_fullbatch else pieces_valid)
     final_loss = final_arc_loss_sum/cur_div_arc + final_label_loss_sum/cur_div_lab
     return final_loss, (arc_scores, label_scores), info
Example #28
 def __init__(self, pc: BK.ParamCollection, input_dim: int,
              conf: PlainLMNodeConf, inputter: Inputter):
     super().__init__(pc, conf, name="PLM")
     self.conf = conf
     self.inputter = inputter
     self.input_dim = input_dim
     self.split_input_blm = conf.split_input_blm
     # this step is performed at the embedder, thus still does not influence the inputter
     self.add_root_token = self.inputter.embedder.add_root_token
     # vocab and padder
     vpack = inputter.vpack
     vocab_word = vpack.get_voc("word")
     # models
     real_input_dim = input_dim // 2 if self.split_input_blm else input_dim
     if conf.hid_dim <= 0:  # no hidden layer
         self.l2r_hid_layer = self.r2l_hid_layer = None
         self.pred_input_dim = real_input_dim
     else:
         self.l2r_hid_layer = self.add_sub_node(
             "l2r_h",
             Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
         self.r2l_hid_layer = self.add_sub_node(
             "r2l_h",
             Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
         self.pred_input_dim = conf.hid_dim
     # todo(note): unk is the first one above real words
     self.pred_size = min(conf.max_pred_rank + 1, vocab_word.unk)
     if conf.tie_input_embeddings:
         zwarn("Tie all preds in plm with input embeddings!!")
         self.l2r_pred = self.r2l_pred = None
         self.inputter_embed_node = self.inputter.embedder.get_node("word")
     else:
         self.l2r_pred = self.add_sub_node(
             "l2r_p",
             Affine(pc,
                    self.pred_input_dim,
                    self.pred_size,
                    init_rop=NoDropRop()))
         if conf.tie_bidirect_pred:
             self.r2l_pred = self.l2r_pred
         else:
             self.r2l_pred = self.add_sub_node(
                 "r2l_p",
                 Affine(pc,
                        self.pred_input_dim,
                        self.pred_size,
                        init_rop=NoDropRop()))
         self.inputter_embed_node = None
         if conf.init_pred_from_pretrain:
             npvec = vpack.get_emb("word")
             if npvec is None:
                 zwarn(
                     "Pretrained vector not provided, skip init pred embeddings!!"
                 )
             else:
                 with BK.no_grad_env():
                     self.l2r_pred.ws[0].copy_(
                         BK.input_real(npvec[:self.pred_size].T))
                     self.r2l_pred.ws[0].copy_(
                         BK.input_real(npvec[:self.pred_size].T))
                 zlog(
                     f"Init pred embeddings from pretrained vectors (size={self.pred_size})."
                 )
Example #29
 def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr,
          go1_pack, training: bool, margin: float):
     # first do decoding and related preparation
     with BK.no_grad_env():
         _, _, g_packs, p_packs = self.decode(insts, enc_expr,
                                              final_valid_expr, go1_pack,
                                              training, margin)
         # flatten the packs (remember to rebase the indexes)
         gold_pack = self._flatten_packs(g_packs)
         pred_pack = self._flatten_packs(p_packs)
         if self.filter_pruned:
             # filter out non-valid (pruned) edges, to avoid prune error
             mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask(
                 final_valid_expr, gold_pack)
             pred_mask = mod_unpruned_mask[
                 pred_pack[0], pred_pack[1]]  # filter by specific mod
             gold_pack = [(None if z is None else z[gold_mask])
                          for z in gold_pack]
             pred_pack = [(None if z is None else z[pred_mask])
                          for z in pred_pack]
     # calculate the scores for loss
     gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack
     pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack
     gold_arc_score, gold_label_score_all = self._get_basic_score(
         enc_expr, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes,
         gold_gp_idxes)
     pred_arc_score, pred_label_score_all = self._get_basic_score(
         enc_expr, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes,
         pred_gp_idxes)
     # whether have labeled scores
     if self.system_labeled:
         gold_label_score = BK.gather_one_lastdim(
             gold_label_score_all, gold_lab_idxes).squeeze(-1)
         pred_label_score = BK.gather_one_lastdim(
             pred_label_score_all, pred_lab_idxes).squeeze(-1)
         ret_scores = (gold_arc_score, pred_arc_score, gold_label_score,
                       pred_label_score)
         pred_full_scores, gold_full_scores = pred_arc_score + pred_label_score, gold_arc_score + gold_label_score
     else:
         ret_scores = (gold_arc_score, pred_arc_score)
         pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score
     # hinge loss: filter-margin by loss*margin to be aware of search error
     if self.filter_margin:
         with BK.no_grad_env():
             mat_shape = BK.get_shape(enc_expr)[:2]  # [bs, slen]
             heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                            gold_b_idxes, gold_m_idxes,
                                            gold_h_idxes)
             heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                            pred_b_idxes, pred_m_idxes,
                                            pred_h_idxes)
             error_count = (heads_gold != heads_pred).float()
             if self.system_labeled:
                 labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                                 gold_b_idxes, gold_m_idxes,
                                                 gold_lab_idxes)
                 labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64,
                                                 pred_b_idxes, pred_m_idxes,
                                                 pred_lab_idxes)
                 error_count += (labels_gold != labels_pred).float()
             scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32,
                                             gold_b_idxes, gold_m_idxes,
                                             gold_full_scores)
             scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32,
                                             pred_b_idxes, pred_m_idxes,
                                             pred_full_scores)
             # todo(note): here, a small 0.1 is to exclude zero error: anyway they will get zero gradient
             sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <=
                          (margin * error_count.sum(-1) + 0.1)).float()
             num_valid_sent = float(BK.get_value(sent_mask.sum()))
         final_loss_sum = (
             pred_full_scores * sent_mask[pred_b_idxes] -
             gold_full_scores * sent_mask[gold_b_idxes]).sum()
     else:
         num_valid_sent = len(insts)
         final_loss_sum = (pred_full_scores - gold_full_scores).sum()
     # prepare final loss
     # divide loss by what?
     num_sent = len(insts)
     num_valid_tok = sum(len(z) for z in insts)
     if self.loss_div_tok:
         final_loss = final_loss_sum / num_valid_tok
     else:
         final_loss = final_loss_sum / num_sent
     final_loss_sum_val = float(BK.get_value(final_loss_sum))
     info = {
         "sent": num_sent,
         "sent_valid": num_valid_sent,
         "tok": num_valid_tok,
         "loss_sum": final_loss_sum_val
     }
     return final_loss, ret_scores, info
Example #30
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # ===== calculate
         scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(
             insts, False)
         full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr,
                                               False, 0.)
         full_label_score = self._score_label_full(scoring_expr_pack,
                                                   mask_expr, False, 0.)
         # normalizing scores
         full_score = None
         final_exp_score = False  # whether to provide PROB by exp
         if self.norm_local and self.loss_prob:
             full_score = BK.log_softmax(full_arc_score,
                                         -1).unsqueeze(-1) + BK.log_softmax(
                                             full_label_score, -1)
             final_exp_score = True
         elif self.norm_hlocal and self.loss_prob:
             # normalize at m dimension, ignore each node's self-finish step.
             full_score = BK.log_softmax(full_arc_score,
                                         -2).unsqueeze(-1) + BK.log_softmax(
                                             full_label_score, -1)
         elif self.norm_single and self.loss_prob:
             if self.conf.iconf.dec_single_neg:
                 # todo(+2): add all-neg for prob explanation
                 full_arc_probs = BK.sigmoid(full_arc_score)
                 full_label_probs = BK.sigmoid(full_label_score)
                 fake_arc_scores = BK.log(full_arc_probs) - BK.log(
                     1. - full_arc_probs)
                 fake_label_scores = BK.log(full_label_probs) - BK.log(
                     1. - full_label_probs)
                 full_score = fake_arc_scores.unsqueeze(
                     -1) + fake_label_scores
             else:
                 full_score = BK.logsigmoid(full_arc_score).unsqueeze(
                     -1) + BK.logsigmoid(full_label_score)
                 final_exp_score = True
         else:
             full_score = full_arc_score.unsqueeze(-1) + full_label_score
         # decode
         mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
         mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(
             full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
         if final_exp_score:
             mst_scores_arr = np.exp(mst_scores_arr)
         # jpos prediction (direct indexing, no conversion as in parsing)
         jpos_preds_expr = jpos_pack[2]
         has_jpos_pred = jpos_preds_expr is not None
         jpos_preds_arr = BK.get_value(
             jpos_preds_expr) if has_jpos_pred else None
         # ===== assign
         info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
         mst_real_labels = self.pred2real_labels(mst_labels_arr)
         for one_idx, one_inst in enumerate(insts):
             cur_length = mst_lengths[one_idx]
             one_inst.pred_heads.set_vals(
                 mst_heads_arr[one_idx]
                 [:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(
                 mst_real_labels[one_idx][:cur_length], self.label_vocab)
             one_inst.pred_par_scores.set_vals(
                 mst_scores_arr[one_idx][:cur_length])
             if has_jpos_pred:
                 one_inst.pred_poses.build_vals(
                     jpos_preds_arr[one_idx][:cur_length],
                     self.bter.pos_vocab)
         return info
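
With `norm_local`, the full score is a locally normalized joint log-probability: a log-softmax over heads plus a log-softmax over labels, and exponentiating the decoded scores afterwards yields probabilities. The core of that combination in plain PyTorch (shapes are assumptions):

import torch
import torch.nn.functional as F

bs, slen, nlab = 2, 6, 10
full_arc_score = torch.randn(bs, slen, slen)          # [bs, mod, head]
full_label_score = torch.randn(bs, slen, slen, nlab)  # [bs, mod, head, label]

# log p(head | mod) + log p(label | mod, head)
full_score = F.log_softmax(full_arc_score, dim=-1).unsqueeze(-1) \
             + F.log_softmax(full_label_score, dim=-1)
full_prob = full_score.exp()  # exp recovers probabilities, as when final_exp_score is set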