def forward_features(self, ids_expr, mask_expr, typeids_expr, other_embed_exprs: List):
    bmodel = self.model
    bmodel_embedding = bmodel.embeddings
    bmodel_encoder = bmodel.encoder
    # prepare
    attention_mask = mask_expr
    token_type_ids = BK.zeros(BK.get_shape(ids_expr)).long() if typeids_expr is None else typeids_expr
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    # extended_attention_mask = extended_attention_mask.to(dtype=next(bmodel.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    # embeddings
    cur_layer = 0
    if self.trainable_min_layer <= 0:
        last_output = bmodel_embedding(ids_expr, position_ids=None, token_type_ids=token_type_ids)
    else:
        with BK.no_grad_env():
            last_output = bmodel_embedding(ids_expr, position_ids=None, token_type_ids=token_type_ids)
    # extra embeddings (this implies overall gradient requirements!!)
    for one_eidx, one_embed in enumerate(self.other_embeds):
        last_output += one_embed(other_embed_exprs[one_eidx])  # [bs, slen, D]
    # =====
    all_outputs = []
    if self.layer_is_output[cur_layer]:
        all_outputs.append(last_output)
    cur_layer += 1
    # todo(note): be careful about the indexes!
    # not-trainable encoders
    trainable_min_layer_idx = max(0, self.trainable_min_layer - 1)
    with BK.no_grad_env():
        for layer_module in bmodel_encoder.layer[:trainable_min_layer_idx]:
            last_output = layer_module(last_output, extended_attention_mask, None)[0]
            if self.layer_is_output[cur_layer]:
                all_outputs.append(last_output)
            cur_layer += 1
    # trainable encoders
    for layer_module in bmodel_encoder.layer[trainable_min_layer_idx:self.output_max_layer]:
        last_output = layer_module(last_output, extended_attention_mask, None)[0]
        if self.layer_is_output[cur_layer]:
            all_outputs.append(last_output)
        cur_layer += 1
    assert cur_layer == self.output_max_layer + 1
    # stack
    if len(all_outputs) == 1:
        ret_expr = all_outputs[0].unsqueeze(-2)
    else:
        ret_expr = BK.stack(all_outputs, -2)  # [BS, SLEN, LAYER, D]
    final_ret_exp = self.output_f(ret_expr)
    return final_ret_exp
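# Illustrative sketch (not part of the original module): the additive attention-mask trick used above
# turns a 0/1 padding mask of shape [bs, slen] into a bias of shape [bs, 1, 1, slen] that is 0 for real
# tokens and -10000 for padding, so masked positions get ~0 weight after softmax. Plain torch is used
# here only as a stand-in assumption for the BK tensor backend.
def _demo_extended_attention_mask():
    import torch
    attention_mask = torch.tensor([[1., 1., 1., 0.]])       # [bs=1, slen=4], last token is padding
    extended = attention_mask.unsqueeze(1).unsqueeze(2)      # [1, 1, 1, 4]
    extended = (1.0 - extended) * -10000.0                   # 0 for real tokens, -10000 for padding
    attn_logits = torch.zeros(1, 1, 4, 4) + extended         # broadcast over heads and query positions
    attn_probs = torch.softmax(attn_logits, dim=-1)
    assert torch.allclose(attn_probs[..., -1], torch.zeros(1, 1, 4), atol=1e-4)  # padding weight ~0
    return attn_probs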
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    iconf = self.conf.iconf
    pconf = iconf.pruning_conf
    with BK.no_grad_env():
        self.refresh_batch(False)
        if iconf.use_pruning:
            # todo(note): for the testing of pruning mode, use the scores instead
            if self.g1_use_aux_scores:
                valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune(insts, self.num_label, pconf)
            else:
                valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch(insts, pconf)
            valid_mask_f = valid_mask.float()  # [*, len, len]
            mask_value = Constants.REAL_PRAC_MIN
            full_score = arc_score.unsqueeze(-1) + label_score
            full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1)
            info_pruning = G1Parser.collect_pruning_info(insts, valid_mask_f)
            jpos_pack = [None, None, None]
        else:
            input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
            mask_expr = BK.input_real(mask_arr)
            full_score = self.scorer_helper.score_full(enc_repr)
            info_pruning = None
        # =====
        self._decode(insts, full_score, mask_expr, "g1")
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
        # -----
        info = {"sent": len(insts), "tok": sum(map(len, insts))}
        if info_pruning is not None:
            info.update(info_pruning)
        return info
def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.conf = conf
    # vocab and padder
    self.word_vocab = vpack.get_voc("word")
    self.padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)  # todo(note): <pad>-id is very large
    # models
    self.hid_layer = self.add_sub_node("hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
    self.pred_layer = self.add_sub_node("pred", Affine(pc, conf.hid_dim, conf.max_pred_rank + 1, init_rop=NoDropRop()))
    if conf.init_pred_from_pretrain:
        npvec = vpack.get_emb("word")
        if npvec is None:
            zwarn("Pretrained vector not provided, skip init pred embeddings!!")
        else:
            with BK.no_grad_env():
                self.pred_layer.ws[0].copy_(BK.input_real(npvec[:conf.max_pred_rank + 1].T))
            zlog(f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1}).")
def inference_on_batch(self, insts: List[DocInstance], **kwargs):
    self.refresh_batch(False)
    # -----
    if len(insts) == 0:
        return {}
    # -----
    # todo(note): first do shallow copy!
    for one_doc in insts:
        for one_sent in one_doc.sents:
            one_sent.pred_entity_fillers = [z for z in one_sent.entity_fillers]
            one_sent.pred_events = [shallow_copy(z) for z in one_sent.events]
    # -----
    ndoc, nsent = len(insts), 0
    iconf = self.conf.iconf
    with BK.no_grad_env():
        # splitting into buckets
        all_packs = self.bter.run(insts, training=False)
        for one_pack in all_packs:
            ms_items, bert_expr, basic_expr = one_pack
            nsent += len(ms_items)
            self.predictor.predict(ms_items, bert_expr, basic_expr)
    info = {"doc": ndoc, "sent": nsent, "num_evt": sum(len(z.pred_events) for z in insts)}
    if iconf.decode_verbose:
        zlog(f"Decode one mini-batch: {info}")
    return info
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # encode
        enc_expr, mask_expr = self.enc.run(insts, False, input_word_mask_repl=None)  # no masklm in testing
        # decode
        self.dec.predict(insts, enc_expr, mask_expr)
        # =====
        # test for masklm
        input_word_mask_repl_arr, output_pred_mask_repl_arr, output_pred_idx_arr = self.masklm.prepare(insts, False)
        enc_expr2, mask_expr2 = self.enc.run(insts, False, input_word_mask_repl=input_word_mask_repl_arr)
        masklm_loss = self.masklm.loss(enc_expr2, output_pred_mask_repl_arr, output_pred_idx_arr)
        masklm_loss_val, masklm_loss_count, masklm_corr_count = [z.item() for z in masklm_loss[0]]
        # =====
        info = {"sent": len(insts), "tok": sum(map(len, insts)),
                "masklm_loss_val": masklm_loss_val, "masklm_loss_count": masklm_loss_count,
                "masklm_corr_count": masklm_corr_count}
        return info
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = \
                [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
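# Illustrative sketch (not part of the original code): the greedy decode above flattens the last two
# dimensions [h, L] of the score tensor, takes a single argmax, and recovers head = idx // L and
# label = idx % L. A minimal plain-torch check, assuming the same [bs, m, h, L] score layout.
def _demo_combined_argmax():
    import torch
    bs, slen, num_label = 2, 5, 7
    scores = torch.randn(bs, slen, slen, num_label)   # [bs, m, h, L]
    flat = scores.view(bs, slen, -1)                  # [bs, m, h*L]
    best_scores, best_idxes = flat.max(dim=-1)
    heads, labels = best_idxes // num_label, best_idxes % num_label
    # same result as two chained maxes: first over labels, then over heads
    label_max, _ = scores.max(dim=-1)                 # [bs, m, h]
    head_max, head_arg = label_max.max(dim=-1)        # [bs, m]
    assert torch.equal(heads, head_arg) and torch.allclose(best_scores, head_max)
    return heads, labels, best_scores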
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    _LEQ_MAP = lambda x, m: [max(m, z) for z in x]  # larger-or-equal map
    # with BK.no_grad_env():
    self.refresh_batch(False)
    scoring_prep_f = lambda ones: self._prepare_score(ones, False)[0]
    get_best_state_f = lambda ag: sorted(ag.ends, key=lambda s: s.score_accu, reverse=True)[0]
    ags, info = self.inferencer.decode(insts, scoring_prep_f, False)
    # put the results in place
    for one_inst, one_ag in zip(insts, ags):
        best_state = get_best_state_f(one_ag)
        one_inst.pred_heads.set_vals(_LEQ_MAP(best_state.list_arc, 0))  # directly int-val for heads
        # todo(warn): already the correct labels, no need to transform
        # -- one_inst.pred_labels.build_vals(self.pred2real_labels(best_state.list_label), self.label_vocab)
        one_inst.pred_labels.build_vals(_LEQ_MAP(best_state.list_label, 1), self.label_vocab)
        # todo(warn): add children ordering in MISC field
        one_inst.pred_miscs.set_vals(self._get_chs_ordering(best_state))
    # check search err
    if self.conf.iconf.check_serr:
        ags_fo, _ = self.fber.force_decode(insts, scoring_prep_f, False)
        serr_sent = 0
        serr_tok = 0
        for ag, ag_fo in zip(ags, ags_fo):
            best_state = get_best_state_f(ag)
            best_fo_state = get_best_state_f(ag_fo)
            # for this one, only care about UAS
            if best_state.score_accu < best_fo_state.score_accu:
                cur_serr_tok = sum((1 if a != b else 0) for a, b in
                                   zip(best_state.list_arc[1:], best_fo_state.list_arc[1:]))
                if cur_serr_tok > 0:
                    serr_sent += 1
                    serr_tok += cur_serr_tok
        info["serr_sent"] = serr_sent
        info["serr_tok"] = serr_tok
    return info
def score_on_batch(self, insts: List[ParseInstance]):
    with BK.no_grad_env():
        self.refresh_batch(False)
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        # mask_expr = BK.input_real(mask_arr)
        arc_score = self.scorer_helper.score_arc(enc_repr)
        label_score = self.scorer_helper.score_label(enc_repr)
        return arc_score.squeeze(-1), label_score
def score_on_batch(self, insts: List[ParseInstance]):
    with BK.no_grad_env():
        self.refresh_batch(False)
        scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False)
        full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.)
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.)
        return full_arc_score, full_label_score
def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
    conf = self.conf
    self.refresh_batch(False)
    # print(f"{len(insts)}: {insts[0].sid}")
    with BK.no_grad_env():
        # decode for dpar
        input_map = self.model.inputter(insts)
        emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False)
        input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
        self.dpar.predict(insts, BK.zeros([1, 1]), input_t, mask_t)
    return {}
def score_and_select(self, enc_expr, pad_mask):
    conf = self.conf
    if conf.ns_no_back:
        enc_expr = enc_expr.detach()
    all_scores = self.scorer(enc_expr).squeeze(-1)  # [*, slen]
    # only for getting the mask
    with BK.no_grad_env():
        masked_all_scores = all_scores + (1. - pad_mask) * Constants.REAL_PRAC_MIN
        res_mask = self._select_topk(masked_all_scores, pad_mask, pad_mask, conf.topk_ratio, conf.thresh_k)
    return res_mask, all_scores
def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
    conf = self.conf
    self.refresh_batch(False)
    with BK.no_grad_env():
        # special mode
        # use: CUDA_VISIBLE_DEVICES=3 PYTHONPATH=../../src/ python3 -m pdb ../../src/tasks/cmd.py zmlm.main.test ${RUN_DIR}/_conf device:0 dict_dir:${RUN_DIR}/ model_load_name:${RUN_DIR}/zmodel.best test:./_en.debug test_interactive:1
        if conf.test_interactive:
            iinput_sent = input(">> (Interactive testing) Input sent sep by blanks: ")
            iinput_tokens = iinput_sent.split()
            if len(iinput_sent) > 0:
                iinput_inst = GeneralSentence.create(iinput_tokens)
                iinput_inst.word_seq.set_idxes([self.word_vocab.get_else_unk(w) for w in iinput_inst.word_seq.vals])
                iinput_inst.char_seq.build_idxes(self.inputter.vpack.get_voc("char"))
                iinput_map = self.inputter([iinput_inst])
                iinput_erase_mask = np.asarray([[z == "Z" for z in iinput_tokens]]).astype(dtype=np.float32)
                iinput_masked_map = self.inputter.mask_input(iinput_map, iinput_erase_mask, set("pos"))
                emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(iinput_masked_map, collect_loss=False, insts=[iinput_inst])
                mlm_loss = self.masklm.loss(enc_t, iinput_erase_mask, iinput_map)
                dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
                self.dpar.predict([iinput_inst], enc_t, dpar_input_attn, mask_t)
                self.upos.predict([iinput_inst], enc_t, mask_t)
                # print them
                import pandas as pd
                cur_fields = {
                    "idxes": list(range(1, len(iinput_inst) + 1)),
                    "word": iinput_inst.word_seq.vals,
                    "pos": iinput_inst.pred_pos_seq.vals,
                    "head": iinput_inst.pred_dep_tree.heads[1:],
                    "dlab": iinput_inst.pred_dep_tree.labels[1:],
                }
                zlog(f"Result:\n{pd.DataFrame(cur_fields).to_string()}")
            return {}  # simply return here for interactive mode
        # -----
        # test for MLM simply as in training (use special separate rand_gen to keep the masks the same for testing)
        # todo(+2): do we need to keep testing/validing during training the same? Currently not!
        info = self.fb_on_batch(insts, training=False, rand_gen=self.testing_rand_gen, assign_attns=conf.testing_get_attns)
        # -----
        if len(insts) == 0:
            return info
        # decode for dpar
        input_map = self.inputter(insts)
        emb_t, mask_t, enc_t, cache, _ = self._emb_and_enc(input_map, collect_loss=False, insts=insts)
        dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
        self.dpar.predict(insts, enc_t, dpar_input_attn, mask_t)
        self.upos.predict(insts, enc_t, mask_t)
        if self.ner is not None:
            self.ner.predict(insts, enc_t, mask_t)
        # -----
        if conf.testing_get_attns:
            if conf.enc_choice == "vrec":
                self._assign_attns_item(insts, "orig", cache=cache)
            elif conf.enc_choice in ["original"]:
                pass
            else:
                raise NotImplementedError()
        return info
def inference_on_batch(self, insts: List[DocInstance], **kwargs):
    self.refresh_batch(False)
    # -----
    if len(insts) == 0:
        return {}
    # -----
    ndoc, nsent = len(insts), 0
    iconf = self.conf.iconf
    # =====
    # get tmp ms_items for each event
    input_ms_items = self._insts2msitems(insts)
    # -----
    if len(input_ms_items) == 0:
        return {}
    # -----
    with BK.no_grad_env():
        # splitting into buckets
        all_packs = self.bter.run(input_ms_items, training=False)
        for one_pack in all_packs:
            ms_items, bert_expr, basic_expr = one_pack
            nsent += len(ms_items)
            # cands
            if iconf.lookup_ef:
                self._lookup_efs(ms_items)
            else:
                self.cand_extractor.predict(ms_items, bert_expr, basic_expr)
            # args
            if iconf.pred_arg:
                self.arg_linker.predict(ms_items, bert_expr, basic_expr)
            # span
            if iconf.pred_span:
                self.span_expander.predict(ms_items, bert_expr)
        # put back all predictions
        self._putback_preds(input_ms_items)
    # collect all stats
    num_ef, num_evt, num_arg = 0, 0, 0
    for one_doc in insts:
        for one_sent in one_doc.sents:
            num_ef += len(one_sent.pred_entity_fillers)
            num_evt += len(one_sent.pred_events)
            num_arg += sum(len(z.links) for z in one_sent.pred_events)
    info = {"doc": ndoc, "sent": nsent, "num_ef": num_ef, "num_evt": num_evt, "num_arg": num_arg}
    if iconf.decode_verbose:
        zlog(f"Decode one mini-batch: {info}")
    return info
def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1., rand_gen=None, assign_attns=False, **kwargs):
    self.refresh_batch(training)
    # get inputs with models
    with BK.no_grad_env():
        input_map = self.model.inputter(insts)
        emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True)
        input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
    losses = [self.dpar.loss(insts, BK.zeros([1, 1]), input_t, mask_t)]
    # -----
    info = self.collect_loss_and_backward(losses, training, loss_factor)
    info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
    return info
def prune_on_batch(self, insts: List[ParseInstance], pconf: PruneG1Conf):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        mask_expr = BK.input_real(mask_arr)
        arc_score = self.scorer_helper.score_arc(enc_repr)
        label_score = self.scorer_helper.score_label(enc_repr)
        final_valid_mask, arc_marginals = G1Parser.prune_with_scores(arc_score, label_score, mask_expr, pconf)
        return final_valid_mask, arc_score.squeeze(-1), label_score, mask_expr, arc_marginals
def loss(self, enc_expr, pad_mask, gold_mask, margin: float):
    conf = self.conf
    # =====
    # first testing-mode scoring and selecting
    res_mask, all_scores = self.score_and_select(enc_expr, pad_mask)
    # add gold
    if conf.ns_add_gold:
        res_mask += gold_mask
        res_mask.clamp_(max=1.)
    # =====
    with BK.no_grad_env():
        # how to select instances for training
        if conf.train_ratio2gold > 0.:
            # use gold-ratio for training
            masked_all_scores = all_scores + (1. - pad_mask + gold_mask) * Constants.REAL_PRAC_MIN
            loss_mask = self._select_topk(masked_all_scores, pad_mask, gold_mask, conf.train_ratio2gold, None)
            loss_mask += gold_mask
            loss_mask.clamp_(max=1.)
        elif not conf.ns_add_gold:
            loss_mask = res_mask + gold_mask
            loss_mask.clamp_(max=1.)
        else:
            # we already have the gold
            loss_mask = res_mask
    # ===== calculating losses [*, L]
    # first aug scores by margin
    aug_scores = all_scores - (conf.margin_pos * margin) * gold_mask + (conf.margin_neg * margin) * (1. - gold_mask)
    if self.loss_hinge:
        # multiply pos instances with -1
        flipped_scores = aug_scores * (1. - 2 * gold_mask)
        losses_all = BK.clamp(flipped_scores, min=0.)
    elif self.loss_prob:
        losses_all = BK.binary_cross_entropy_with_logits(aug_scores, gold_mask, reduction='none')
        if conf.no_loss_satisfy_margin:
            unsatisfy_mask = ((aug_scores * (1. - 2 * gold_mask)) > 0.).float()  # those still with hinge loss
            losses_all *= unsatisfy_mask
    else:
        raise NotImplementedError()
    # return prediction and loss(sum/count)
    loss_sum = (losses_all * loss_mask).sum()
    if conf.train_return_loss_mask:
        return [[loss_sum, loss_mask.sum()]], loss_mask
    else:
        return [[loss_sum, loss_mask.sum()]], res_mask
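# Illustrative sketch (not part of the original code): the hinge branch above multiplies scores by
# (1 - 2*gold) so that gold items contribute max(0, -score) and negative items max(0, score); a single
# clamp therefore implements the hinge for both classes at once (margins are already folded into the
# scores beforehand). Plain-torch check under that assumption.
def _demo_flipped_hinge():
    import torch
    scores = torch.tensor([2.0, -0.5, 1.5, -3.0])
    gold = torch.tensor([1.0, 1.0, 0.0, 0.0])
    flipped = scores * (1. - 2 * gold)                 # gold scores negated, negatives unchanged
    losses = torch.clamp(flipped, min=0.)
    expected = torch.where(gold > 0, torch.clamp(-scores, min=0.), torch.clamp(scores, min=0.))
    assert torch.allclose(losses, expected)
    return losses                                      # tensor([0.0, 0.5, 1.5, 0.0])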
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        # g1 score
        g1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
        # decode for parsing
        self.inferencer.decode(insts, enc_repr, mask_arr, g1_pack, self.label_vocab)
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
        # -----
        info = {"sent": len(insts), "tok": sum(map(len, insts))}
        return info
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    # iconf = self.conf.iconf
    with BK.no_grad_env():
        self.refresh_batch(False)
        full_score, _, jpos_pack, mask_expr, _, _ = \
            self._score(insts, False, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
        # collect the results together
        # =====
        self._decode(insts, full_score, mask_expr, "s2")
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
        # -----
        info = {"sent": len(insts), "tok": sum(map(len, insts))}
        return info
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr, lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(-1) * \
            BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))  # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
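# Illustrative sketch (not part of the original code): the labeled marginals above factor as
#   p(m, h, l) = p(m, h) * softmax_l(s(m, h, l)),
# since exp(s(m,h,l) - logsumexp_l s(m,h,:)) is exactly a softmax over the label dimension.
# Minimal plain-torch check of that identity; the tree-structured p(m,h) itself comes from the
# CPU marginal_proj routine and is not reproduced here (a random stand-in is used instead).
def _demo_labeled_marginal_factorization():
    import torch
    scores = torch.randn(2, 4, 4, 5)                       # [bs, m, h, L]
    scores_unlabeled = torch.logsumexp(scores, dim=-1)     # [bs, m, h]
    marg_unlabeled = torch.rand(2, 4, 4)                   # stand-in for projective arc marginals
    marg_labeled_a = marg_unlabeled.unsqueeze(-1) * torch.exp(scores - scores_unlabeled.unsqueeze(-1))
    marg_labeled_b = marg_unlabeled.unsqueeze(-1) * torch.softmax(scores, dim=-1)
    assert torch.allclose(marg_labeled_a, marg_labeled_b, atol=1e-6)
    return marg_labeled_a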
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    # iconf = self.conf.iconf
    with BK.no_grad_env():
        self.refresh_batch(False)
        # pruning and scores from g1
        valid_mask, go1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
        # encode
        input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
        mask_expr = BK.input_real(mask_arr)
        # decode
        final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
        ret_heads, ret_labels, _, _ = self.dl.decode(insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
        # collect the results together
        all_heads = Helper.join_list(ret_heads)
        if ret_labels is None:
            # todo(note): simply get labels from the go1-label classifier; must provide g1parser
            if go1_pack is None:
                _, go1_pack = self._get_g1_pack(insts, 1., 1.)
            _, go1_label_max_idxes = go1_pack[1].max(-1)  # [bs, slen, slen]
            pred_heads_arr, _ = self.predict_padder.pad(all_heads)  # [bs, slen]
            pred_heads_expr = BK.input_idx(pred_heads_arr)
            pred_labels_expr = BK.gather_one_lastdim(go1_label_max_idxes, pred_heads_expr).squeeze(-1)
            all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
        else:
            all_labels = np.concatenate(ret_labels, 0)
        # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
        for one_idx, one_inst in enumerate(insts):
            cur_length = len(one_inst) + 1
            one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab)
            # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
        # =====
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
        # -----
        info = {"sent": len(insts), "tok": sum(map(len, insts))}
        return info
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr, lengths_arr, labeled=False)  # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax, mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(mst_scores_arr)
def inference_on_batch(self, insts: List[DocInstance], **kwargs):
    self.refresh_batch(False)
    test_constrain_evt_types = self.test_constrain_evt_types
    ndoc, nsent = len(insts), 0
    iconf = self.conf.iconf
    with BK.no_grad_env():
        # splitting into buckets
        all_packs = self.bter.run(insts, training=False)
        for one_pack in all_packs:
            # =====
            # predict
            sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
            nsent += len(sent_insts)
            mask_expr = BK.input_real(mask_arr)
            # entity and filler
            if iconf.lookup_ef:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    self._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, ret_copy=True)
            elif iconf.pred_ef:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    self._inference_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, self.ef_creator)
            else:
                ef_items = [[] for _ in range(len(sent_insts))]
                ef_valid_mask = BK.zeros((len(sent_insts), 0))
                ef_widxes = ef_lab_idxes = ef_lab_embeds = None
            # event
            if iconf.lookup_evt:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    self._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, ret_copy=True)
            else:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    self._inference_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, self.evt_creator)
            # arg
            if iconf.pred_arg:
                # todo(note): for this step of decoding, we only consider inner-sentence pairs
                # todo(note): inplaced
                self._inference_args(ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                                     evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt)
            # =====
            # assign
            for one_sent_inst, one_ef_items, one_ef_valid, one_evt_items, one_evt_valid in \
                    zip(sent_insts, ef_items, BK.get_value(ef_valid_mask), evt_items, BK.get_value(evt_valid_mask)):
                # entity and filler
                one_ef_items = [z for z, va in zip(one_ef_items, one_ef_valid) if (va and z is not None)]
                one_sent_inst.pred_entity_fillers = one_ef_items
                # event
                one_evt_items = [z for z, va in zip(one_evt_items, one_evt_valid) if (va and z is not None)]
                if test_constrain_evt_types is not None:
                    one_evt_items = [z for z in one_evt_items if z.type in test_constrain_evt_types]
                # =====
                # todo(note): special rule (actually a simple rule based extender)
                if iconf.expand_evt_compound:
                    for one_evt in one_evt_items:
                        one_hard_span = one_evt.mention.hard_span
                        sid, hwid, _ = one_hard_span.position(True)
                        assert one_hard_span.length == 1  # currently no way to predict more
                        if hwid + 1 < one_sent_inst.length:
                            if one_sent_inst.uposes.vals[hwid] == "VERB" and one_sent_inst.ud_heads.vals[hwid+1] == hwid \
                                    and one_sent_inst.ud_labels.vals[hwid+1] == "compound":
                                one_hard_span.length += 1
                # =====
                one_sent_inst.pred_events = one_evt_items
    return {"doc": ndoc, "sent": nsent}
def fb_on_batch(self, annotated_insts: List[DocInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    assert self.train_constrain_evt_types is None, "Not implemented for training constrain_types"
    ndoc, nsent = len(annotated_insts), 0
    lambda_cand, lambda_arg, lambda_span = self.lambda_cand.value, self.lambda_arg.value, self.lambda_span.value
    # simply multi-task training, no explicit interactions between them
    # -----
    # only include the ones with enough annotations
    valid_insts = []
    for one_inst in annotated_insts:
        if all(z.events is not None and z.entity_fillers is not None for z in one_inst.sents):
            valid_insts.append(one_inst)
    input_ms_items = self._insts2msitems(valid_insts)
    if len(input_ms_items) == 0:
        return {}
    # -----
    all_packs = self.bter.run(input_ms_items, training=training)
    all_cand_losses, all_arg_losses, all_span_losses = [], [], []
    # =====
    mix_pred_ef_rate = self.conf.mix_pred_ef_rate
    mix_pred_ef_rate_outside = self.conf.mix_pred_ef_rate_outside
    mix_pred_ef_count = 0
    # cur_margin = self.margin.value  # todo(+N): currently not used!
    for one_pack in all_packs:
        ms_items, bert_expr, basic_expr = one_pack
        nsent += len(ms_items)
        if lambda_cand > 0.:
            # todo(note): no need to clean up pred ones since sentences and containers are copied
            cand_losses = self.cand_extractor.loss(ms_items, bert_expr, basic_expr)
            all_cand_losses.append(cand_losses)
            # predict as candidates
            with BK.no_grad_env():
                self.cand_extractor.predict(ms_items, bert_expr, basic_expr)
            # mix into gold ones; no need to cleanup since these are copies by _insts2msitems
            for one_msent in ms_items:
                center_idx = one_msent.center_idx
                for one_sidx, one_sent in enumerate(one_msent.sents):
                    hit_posi = set()
                    for one_ef in one_sent.entity_fillers:
                        posi = one_ef.mention.hard_span.position()
                        hit_posi.add(posi)
                    # add predicted ones
                    cur_mix_rate = mix_pred_ef_rate if (center_idx == one_sidx) else mix_pred_ef_rate_outside
                    for one_ef in one_sent.pred_entity_fillers:
                        posi = one_ef.mention.hard_span.position()
                        if posi not in hit_posi and next(self.random_sample_stream) <= cur_mix_rate:
                            hit_posi.add(posi)
                            # one_ef.is_mix = True
                            # todo(note): these are not TRUE efs, but only mixing preds as neg examples for training
                            one_sent.entity_fillers.append(one_ef)
                            mix_pred_ef_count += 1
        if lambda_arg > 0.:
            # todo(note): since currently we are predicting all candidates for one event
            arg_losses = self.arg_linker.loss(ms_items, bert_expr, basic_expr, dynamic_prepare=True)
            all_arg_losses.append(arg_losses)
        if lambda_span > 0.:
            span_losses = self.span_expander.loss(ms_items, bert_expr)
            all_span_losses.append(span_losses)
    # =====
    # final loss sum and backward
    info = {"doc": ndoc, "sent": nsent, "fb": 1, "mix_pef": mix_pred_ef_count}
    if len(all_packs) > 0:
        self.collect_loss_and_backward(["cand", "arg", "span"],
                                       [all_cand_losses, all_arg_losses, all_span_losses],
                                       [lambda_cand, lambda_arg, lambda_span],
                                       info, training, loss_factor)
    return info
def __init__(self, pc: BK.ParamCollection, bconf: Berter2Conf):
    super().__init__(pc, None, None)
    self.bconf = bconf
    self.model_name = bconf.bert2_model
    zlog(f"Loading pre-trained bert model for Berter2 of {self.model_name}")
    # Load pretrained model/tokenizer
    self.tokenizer = BertTokenizer.from_pretrained(
        self.model_name, do_lower_case=bconf.bert2_lower_case,
        cache_dir=None if (not bconf.bert2_cache_dir) else bconf.bert2_cache_dir)
    self.model = BertModel.from_pretrained(
        self.model_name, output_hidden_states=True,
        cache_dir=None if (not bconf.bert2_cache_dir) else bconf.bert2_cache_dir)
    zlog(f"Load done, move to default device {BK.DEFAULT_DEVICE}")
    BK.to_device(self.model)
    # =====
    # zero padding embeddings?
    if bconf.bert2_zero_pademb:
        with BK.no_grad_env():
            # todo(warn): specific!!
            zlog(f"Unusual operation: make bert's padding embedding (idx0) zero!!")
            self.model.embeddings.word_embeddings.weight[0].fill_(0.)
    # =====
    # check trainable ones and add parameters
    # todo(+N): this part is specific and looking into the lib, can break in further versions!!
    # the idx of layer is [1(embed)] + [N(enc)], that is, layer0 is the output of embeddings
    self.hidden_size = self.model.config.hidden_size
    self.num_bert_layers = len(self.model.encoder.layer) + 1  # +1 for embeddings
    self.output_layers = [i if i >= 0 else (self.num_bert_layers + i) for i in bconf.bert2_output_layers]
    self.layer_is_output = [False] * self.num_bert_layers
    for i in self.output_layers:
        self.layer_is_output[i] = True
    # the highest used layer
    self.output_max_layer = max(self.output_layers) if len(self.output_layers) > 0 else -1
    # from max-layer down
    self.trainable_layers = list(range(self.output_max_layer, -1, -1))[:bconf.bert2_trainable_layers]
    # the lowest trainable layer
    self.trainable_min_layer = min(self.trainable_layers) if len(self.trainable_layers) > 0 else (self.output_max_layer + 1)
    zlog(f"Build Berter2: {self}")
    # add parameters
    prefix_name = self.pc.nnc_name(self.name, True) + "/"
    for layer_idx in self.trainable_layers:
        if layer_idx == 0:
            # add the embedding layer
            infix_name = "embed"
            named_params = self.pc.param_add_external(prefix_name + infix_name, self.model.embeddings)
        else:
            # here we should use the original (-1) index
            infix_name = "enc" + str(layer_idx)
            named_params = self.pc.param_add_external(prefix_name + infix_name, self.model.encoder.layer[layer_idx - 1])
        # add to self.params
        for one_name, one_param in named_params:
            assert f"{infix_name}_{one_name}" not in self.params
            self.params[f"{infix_name}_{one_name}"] = one_param
    # for dropout/mask input
    self.random_sample_stream = Random.stream(Random.random_sample)
    # =====
    # for other inputs; todo(note): still, 0 means all-zero embedding
    self.other_embeds = [
        self.add_sub_node("OE", Embedding(self.pc, vsize, self.hidden_size, fix_row0=True))
        for vsize in bconf.bert2_other_input_vsizes]
    # =====
    # for output
    if bconf.bert2_output_mode == "layered":
        self.output_f = lambda x: x
        self.output_dims = (self.hidden_size, len(self.output_layers),)
    elif bconf.bert2_output_mode == "concat":
        self.output_f = lambda x: x.view(BK.get_shape(x)[:-2] + [-1])  # combine the last two dims
        self.output_dims = (self.hidden_size * len(self.output_layers),)
    elif bconf.bert2_output_mode == "weighted":
        self.output_f = self.add_sub_node("wb", BertFeaturesWeightLayer(pc, len(self.output_layers)))
        self.output_dims = (self.hidden_size,)
    else:
        raise NotImplementedError(f"UNK mode for bert2 output: {bconf.bert2_output_mode}")
def fb_on_batch(self, annotated_insts: List[DocInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    assert self.train_constrain_evt_types is None, "Not implemented for training constrain_types"
    # -----
    if len(annotated_insts) == 0:
        return {}
    # -----
    ndoc, nsent = len(annotated_insts), 0
    lambda_mention_ef, lambda_mention_evt, lambda_arg, lambda_span = \
        self.lambda_mention_ef.value, self.lambda_mention_evt.value, self.lambda_arg.value, self.lambda_span.value
    # simply multi-task training, no explicit interactions between them
    all_packs = self.bter.run(annotated_insts, training=training)
    all_ef_mention_losses = []
    all_evt_mention_losses = []
    all_arg_losses = []
    all_span_losses = []
    #
    mix_pred_ef_rate = self.conf.mix_pred_ef_rate
    mix_pred_ef = self.conf.mix_pred_ef
    mix_pred_ef_count = 0
    # =====
    cur_margin = self.margin.value
    for one_pack in all_packs:
        ms_items, bert_expr, basic_expr = one_pack
        nsent += len(ms_items)
        if lambda_mention_ef > 0.:
            if mix_pred_ef:
                # clear previous added fake ones
                for one_msent in ms_items:
                    center_sent = one_msent.sents[one_msent.center_idx]
                    center_sent.entity_fillers = [z for z in center_sent.entity_fillers if not hasattr(z, "is_mix")]
                # -----
            ef_losses = self.ef_extractor.loss(ms_items, bert_expr, basic_expr)
            all_ef_mention_losses.append(ef_losses)
        if lambda_mention_evt > 0.:
            evt_losses = self.evt_extractor.loss(ms_items, bert_expr, basic_expr, margin=cur_margin)
            all_evt_mention_losses.append(evt_losses)
        if lambda_span > 0.:
            span_losses = self.span_expander.loss(ms_items, bert_expr)
            all_span_losses.append(span_losses)
        # predict efs as candidates for args
        if mix_pred_ef:
            with BK.no_grad_env():
                self.ef_extractor.predict(ms_items, bert_expr, basic_expr)
            # mix into gold ones
            for one_msent in ms_items:
                center_sent = one_msent.sents[one_msent.center_idx]
                # since we might cache insts, we do not consider previous mixed ones
                hit_posi = set()
                center_sent.entity_fillers = [z for z in center_sent.entity_fillers if not hasattr(z, "is_mix")]
                for one_ef in center_sent.entity_fillers:
                    posi = one_ef.mention.hard_span.position()
                    hit_posi.add(posi)
                # add predicted ones
                for one_ef in center_sent.pred_entity_fillers:
                    posi = one_ef.mention.hard_span.position()
                    if posi not in hit_posi and next(self.random_sample_stream) <= mix_pred_ef_rate:
                        hit_posi.add(posi)
                        one_ef.is_mix = True
                        # todo(note): these are not TRUE efs, but only mixing preds as neg examples for training
                        center_sent.entity_fillers.append(one_ef)
                        mix_pred_ef_count += 1
    # =====
    # in some mode, we may want to collect predicted efs
    for one_pack in all_packs:
        ms_items, bert_expr, basic_expr = one_pack
        if lambda_arg > 0.:
            arg_losses = self.arg_linker.loss(ms_items, bert_expr, basic_expr, dynamic_prepare=mix_pred_ef)
            all_arg_losses.append(arg_losses)
    # =====
    # final loss sum and backward
    info = {"doc": ndoc, "sent": nsent, "fb": 1, "mix_pef": mix_pred_ef_count}
    if len(all_packs) == 0:
        return info
    self.collect_loss_and_backward(
        ["ef", "evt", "arg", "span"],
        [all_ef_mention_losses, all_evt_mention_losses, all_arg_losses, all_span_losses],
        [lambda_mention_ef, lambda_mention_evt, lambda_arg, lambda_span],
        info, training, loss_factor)
    # =====
    # # for debug
    # zlog([d.dataset for d in annotated_insts])
    # zlog(info)
    # =====
    return info
def inference_on_batch(self, insts: List[DocInstance], **kwargs):
    self.refresh_batch(False)
    test_constrain_evt_types = self.test_constrain_evt_types
    # -----
    if len(insts) == 0:
        return {}
    # -----
    ndoc, nsent = len(insts), 0
    iconf = self.conf.iconf
    with BK.no_grad_env():
        # splitting into buckets
        all_packs = self.bter.run(insts, training=False)
        for one_pack in all_packs:
            ms_items, bert_expr, basic_expr = one_pack
            nsent += len(ms_items)
            # ef
            if iconf.lookup_ef:
                self.ef_extractor.lookup(ms_items)
            elif iconf.pred_ef:
                self.ef_extractor.predict(ms_items, bert_expr, basic_expr)
            # evt
            if iconf.lookup_evt:
                self.evt_extractor.lookup(ms_items, constrain_types=test_constrain_evt_types)
            elif iconf.pred_evt:
                self.evt_extractor.predict(ms_items, bert_expr, basic_expr, constrain_types=test_constrain_evt_types)
        # deal with arg after pred all!!
        if iconf.pred_arg:
            for one_pack in all_packs:
                ms_items, bert_expr, basic_expr = one_pack
                self.arg_linker.predict(ms_items, bert_expr, basic_expr)
                if iconf.pred_span:
                    self.span_expander.predict(ms_items, bert_expr)
    # collect all stats
    num_ef, num_evt, num_arg = 0, 0, 0
    for one_doc in insts:
        for one_sent in one_doc.sents:
            num_ef += len(one_sent.pred_entity_fillers)
            num_evt += len(one_sent.pred_events)
            num_arg += sum(len(z.links) for z in one_sent.pred_events)
            if self.static_span_expander is not None:
                assert not iconf.pred_span, "Not compatible of these two modes!"
                for one_ef in one_sent.pred_entity_fillers:
                    # todo(note): expand phrase by rule
                    one_hard_span = one_ef.mention.hard_span
                    head_wid = one_hard_span.head_wid
                    one_hard_span.wid, one_hard_span.length = self.static_span_expander.expand_span(head_wid, one_sent)
    info = {"doc": ndoc, "sent": nsent, "num_ef": num_ef, "num_evt": num_evt, "num_arg": num_arg}
    if iconf.decode_verbose:
        zlog(f"Decode one mini-batch: {info}")
    return info
def loss(self, insts: List[ParseInstance], enc_repr, mask_arr, g1_pack):
    # todo(WARN): may need sg if using other loss functions
    # first-round search
    cur_margin = self.margin.value
    cur_cost0_weight = self.cost0_weight.value
    with BK.no_grad_env():
        ags = self.searcher.start(insts, self.hm_feature_getter0, enc_repr, mask_arr, g1_pack, margin=cur_margin)
        self.searcher.go(ags)
    # then forward and backward
    # collect only loss-related actions
    toks_all, sent_all = 0, len(insts)
    pieces_all, pieces_no_cost, pieces_serr, pieces_valid = 0, 0, 0, 0
    toks_valid = 0
    arc_valid_weights, label_valid_weights = 0., 0.
    action_list, arc_weight_list, label_weight_list, bidxes_list = [], [], [], []
    bidx = 0
    # =====
    score_getter = self.searcher.ender.plain_ranker
    oracler_ranker = self.searcher.ender.oracle_ranker
    # =====
    for one_inst, one_ag in zip(insts, ags):
        cur_size = len(one_inst)  # excluding ROOT
        toks_all += cur_size
        # for all the pieces
        for sp in one_ag.special_points:
            plain_finals, oracle_finals = sp.plain_finals, sp.oracle_finals
            best_plain = max(plain_finals, key=score_getter)
            best_oracle = max(plain_finals + oracle_finals, key=oracler_ranker)
            cost_plain, cost_oracle = best_plain.cost_accu, best_oracle.cost_accu
            score_plain, score_oracle = score_getter(best_plain), score_getter(best_oracle)
            # if cost_oracle > 0.:  # the gold one cannot be searched?
            #     sent_oracle_has_cost += 1
            # add them
            pieces_all += 1
            if cost_plain <= cost_oracle:
                pieces_no_cost += 1
            elif score_plain < score_oracle:
                pieces_serr += 1  # search error
            else:
                pieces_valid += 1
                plain_states, oracle_states = best_plain.get_path(), best_oracle.get_path()
                toks_valid += len(plain_states)
                # Loss = score(best_plain) - score(best_oracle)
                cur_aw, cur_lw = self._add_lists(plain_states, oracle_states, action_list, arc_weight_list,
                                                 label_weight_list, bidxes_list, bidx, cur_cost0_weight)
                arc_valid_weights += cur_aw
                label_valid_weights += cur_lw
        bidx += 1
    # collect the losses
    info = {"sent": sent_all, "tok": toks_all, "tok_valid": toks_valid,
            "vw_arc": arc_valid_weights, "vw_lab": label_valid_weights,
            "pieces_all": pieces_all, "pieces_no_cost": pieces_no_cost,
            "pieces_serr": pieces_serr, "pieces_valid": pieces_valid}
    if toks_valid == 0:
        return None, None, info
    final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores = \
        self._loss(enc_repr, action_list, arc_weight_list, label_weight_list, bidxes_list)
    # todo(+1): other indicators?
    info["loss_sum_arc"] = BK.get_value(final_arc_loss_sum).item()
    info["loss_sum_lab"] = BK.get_value(final_label_loss_sum).item()
    # how to div
    if self.loss_div_step:
        if self.loss_div_weights:
            cur_div_arc = max(arc_valid_weights, 1.)
            cur_div_lab = max(label_valid_weights, 1.)
        else:
            cur_div_arc = cur_div_lab = (toks_all if self.loss_div_fullbatch else toks_valid)
    else:
        # todo(warn): here use pieces rather than sentences
        cur_div_arc = cur_div_lab = (pieces_all if self.loss_div_fullbatch else pieces_valid)
    final_loss = final_arc_loss_sum / cur_div_arc + final_label_loss_sum / cur_div_lab
    return final_loss, (arc_scores, label_scores), info
def __init__(self, pc: BK.ParamCollection, input_dim: int, conf: PlainLMNodeConf, inputter: Inputter):
    super().__init__(pc, conf, name="PLM")
    self.conf = conf
    self.inputter = inputter
    self.input_dim = input_dim
    self.split_input_blm = conf.split_input_blm
    # this step is performed at the embedder, thus still does not influence the inputter
    self.add_root_token = self.inputter.embedder.add_root_token
    # vocab and padder
    vpack = inputter.vpack
    vocab_word = vpack.get_voc("word")
    # models
    real_input_dim = input_dim // 2 if self.split_input_blm else input_dim
    if conf.hid_dim <= 0:
        # no hidden layer
        self.l2r_hid_layer = self.r2l_hid_layer = None
        self.pred_input_dim = real_input_dim
    else:
        self.l2r_hid_layer = self.add_sub_node("l2r_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
        self.r2l_hid_layer = self.add_sub_node("r2l_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
        self.pred_input_dim = conf.hid_dim
    # todo(note): unk is the first one above real words
    self.pred_size = min(conf.max_pred_rank + 1, vocab_word.unk)
    if conf.tie_input_embeddings:
        zwarn("Tie all preds in plm with input embeddings!!")
        self.l2r_pred = self.r2l_pred = None
        self.inputter_embed_node = self.inputter.embedder.get_node("word")
    else:
        self.l2r_pred = self.add_sub_node("l2r_p", Affine(pc, self.pred_input_dim, self.pred_size, init_rop=NoDropRop()))
        if conf.tie_bidirect_pred:
            self.r2l_pred = self.l2r_pred
        else:
            self.r2l_pred = self.add_sub_node("r2l_p", Affine(pc, self.pred_input_dim, self.pred_size, init_rop=NoDropRop()))
        self.inputter_embed_node = None
        if conf.init_pred_from_pretrain:
            npvec = vpack.get_emb("word")
            if npvec is None:
                zwarn("Pretrained vector not provided, skip init pred embeddings!!")
            else:
                with BK.no_grad_env():
                    self.l2r_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T))
                    self.r2l_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T))
                zlog(f"Init pred embeddings from pretrained vectors (size={self.pred_size}).")
def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr, go1_pack, training: bool, margin: float): # first do decoding and related preparation with BK.no_grad_env(): _, _, g_packs, p_packs = self.decode(insts, enc_expr, final_valid_expr, go1_pack, training, margin) # flatten the packs (remember to rebase the indexes) gold_pack = self._flatten_packs(g_packs) pred_pack = self._flatten_packs(p_packs) if self.filter_pruned: # filter out non-valid (pruned) edges, to avoid prune error mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask( final_valid_expr, gold_pack) pred_mask = mod_unpruned_mask[ pred_pack[0], pred_pack[1]] # filter by specific mod gold_pack = [(None if z is None else z[gold_mask]) for z in gold_pack] pred_pack = [(None if z is None else z[pred_mask]) for z in pred_pack] # calculate the scores for loss gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack gold_arc_score, gold_label_score_all = self._get_basic_score( enc_expr, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes) pred_arc_score, pred_label_score_all = self._get_basic_score( enc_expr, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes) # whether have labeled scores if self.system_labeled: gold_label_score = BK.gather_one_lastdim( gold_label_score_all, gold_lab_idxes).squeeze(-1) pred_label_score = BK.gather_one_lastdim( pred_label_score_all, pred_lab_idxes).squeeze(-1) ret_scores = (gold_arc_score, pred_arc_score, gold_label_score, pred_label_score) pred_full_scores, gold_full_scores = pred_arc_score + pred_label_score, gold_arc_score + gold_label_score else: ret_scores = (gold_arc_score, pred_arc_score) pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score # hinge loss: filter-margin by loss*margin to be aware of search error if self.filter_margin: with BK.no_grad_env(): mat_shape = BK.get_shape(enc_expr)[:2] # [bs, slen] heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_h_idxes) heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_h_idxes) error_count = (heads_gold != heads_pred).float() if self.system_labeled: labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_lab_idxes) labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_lab_idxes) error_count += (labels_gold != labels_pred).float() scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32, gold_b_idxes, gold_m_idxes, gold_full_scores) scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32, pred_b_idxes, pred_m_idxes, pred_full_scores) # todo(note): here, a small 0.1 is to exclude zero error: anyway they will get zero gradient sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <= (margin * error_count.sum(-1) + 0.1)).float() num_valid_sent = float(BK.get_value(sent_mask.sum())) final_loss_sum = ( pred_full_scores * sent_mask[pred_b_idxes] - gold_full_scores * sent_mask[gold_b_idxes]).sum() else: num_valid_sent = len(insts) final_loss_sum = (pred_full_scores - gold_full_scores).sum() # prepare final loss # divide loss by what? 
num_sent = len(insts) num_valid_tok = sum(len(z) for z in insts) if self.loss_div_tok: final_loss = final_loss_sum / num_valid_tok else: final_loss = final_loss_sum / num_sent final_loss_sum_val = float(BK.get_value(final_loss_sum)) info = { "sent": num_sent, "sent_valid": num_valid_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val } return final_loss, ret_scores, info
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    with BK.no_grad_env():
        self.refresh_batch(False)
        # ===== calculate
        scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False)
        full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.)
        full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.)
        # normalizing scores
        full_score = None
        final_exp_score = False  # whether to provide PROB by exp
        if self.norm_local and self.loss_prob:
            full_score = BK.log_softmax(full_arc_score, -1).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
            final_exp_score = True
        elif self.norm_hlocal and self.loss_prob:
            # normalize at m dimension, ignoring each node's self-finish step.
            full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax(full_label_score, -1)
        elif self.norm_single and self.loss_prob:
            if self.conf.iconf.dec_single_neg:
                # todo(+2): add all-neg for prob explanation
                full_arc_probs = BK.sigmoid(full_arc_score)
                full_label_probs = BK.sigmoid(full_label_score)
                fake_arc_scores = BK.log(full_arc_probs) - BK.log(1. - full_arc_probs)
                fake_label_scores = BK.log(full_label_probs) - BK.log(1. - full_label_probs)
                full_score = fake_arc_scores.unsqueeze(-1) + fake_label_scores
            else:
                full_score = BK.logsigmoid(full_arc_score).unsqueeze(-1) + BK.logsigmoid(full_label_score)
                final_exp_score = True
        else:
            full_score = full_arc_score.unsqueeze(-1) + full_label_score
        # decode
        mst_lengths = [len(z) + 1 for z in insts]  # +=1 to include ROOT for mst decoding
        mst_heads_arr, mst_labels_arr, mst_scores_arr = \
            self._decode(full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32))
        if final_exp_score:
            mst_scores_arr = np.exp(mst_scores_arr)
        # jpos prediction (directly index, no converting as in parsing)
        jpos_preds_expr = jpos_pack[2]
        has_jpos_pred = jpos_preds_expr is not None
        jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None
        # ===== assign
        info = {"sent": len(insts), "tok": sum(mst_lengths) - len(insts)}
        mst_real_labels = self.pred2real_labels(mst_labels_arr)
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
            one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length], self.label_vocab)
            one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length])
            if has_jpos_pred:
                one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab)
        return info
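# Illustrative sketch (not part of the original code): under the local normalization above, the combined
# score is log p(h|m) + log p(l|m,h), so exponentiating and summing over both head and label dimensions
# gives 1 for every modifier m. Plain-torch check of that property under the same [bs, m, h] / [bs, m, h, L]
# score layout (tensor shapes here are arbitrary stand-ins).
def _demo_local_normalization():
    import torch
    arc = torch.randn(2, 5, 5)           # [bs, m, h]
    label = torch.randn(2, 5, 5, 7)       # [bs, m, h, L]
    joint = torch.log_softmax(arc, -1).unsqueeze(-1) + torch.log_softmax(label, -1)  # log p(h|m) + log p(l|m,h)
    total = joint.exp().sum(dim=(-1, -2))  # sum over labels and heads
    assert torch.allclose(total, torch.ones(2, 5), atol=1e-5)
    return joint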