Example #1
 def collect_loss_and_backward(self, loss_names, loss_ts, loss_lambdas,
                               info, training, loss_factor):
     final_losses = []
     for one_name, one_losses, one_lambda in zip(loss_names, loss_ts,
                                                 loss_lambdas):
         if one_lambda > 0. and len(one_losses) > 0:
             num_sub_losses = len(one_losses[0])
             coll_sub_losses = []
             for i in range(num_sub_losses):
                 # todo(note): (loss_sum, loss_count, gold_count[opt])
                 this_loss_sum = BK.stack([z[i][0] for z in one_losses]).sum()
                 this_loss_count = BK.stack([z[i][1] for z in one_losses]).sum()
                 info[f"loss_sum_{one_name}{i}"] = this_loss_sum.item()
                 info[f"loss_count_{one_name}{i}"] = this_loss_count.item()
                 # optional extra count
                 if len(one_losses[0][i]) >= 3:  # has gold count
                     info[f"loss_count_extra_{one_name}{i}"] = BK.stack([z[i][2] for z in one_losses]).sum().item()
                 # todo(note): any case that loss-count can be 0?
                 coll_sub_losses.append(this_loss_sum / (this_loss_count + 1e-5))
             # sub losses are already multiplied by sub-lambdas
             weighted_sub_loss = BK.stack(coll_sub_losses).sum() * one_lambda
             final_losses.append(weighted_sub_loss)
     if len(final_losses) > 0:
         final_loss = BK.stack(final_losses).sum()
         if training and final_loss.requires_grad:
             BK.backward(final_loss, loss_factor)
Example #2
 def get_losses_from_attn_list(list_attn_info: List, ts_f, loss_f,
                               loss_prefix, loss_lambda):
     loss_num = None
     loss_counts: List[int] = []
     loss_sums: List[List] = []
     rets = []
     # -----
     for one_attn_info in list_attn_info:  # each update step
         one_ts: List = ts_f(one_attn_info)  # get tensor list from attn_info
         # get number of losses
         if loss_num is None:
             loss_num = len(one_ts)
             loss_counts = [0] * loss_num
             loss_sums = [[] for _ in range(loss_num)]
         else:
             assert len(one_ts) == loss_num, "mismatched ts length"
         # iter them
         for one_t_idx, one_t in enumerate(one_ts):  # iter on the tensor list
             one_loss = loss_f(one_t)
             # need it to be in the corresponding shape
             loss_counts[one_t_idx] += np.prod(BK.get_shape(one_loss)).item()
             loss_sums[one_t_idx].append(one_loss.sum())
     # for different steps
     for i, one_loss_count, one_loss_sums in zip(range(len(loss_counts)),
                                                 loss_counts, loss_sums):
         loss_leaf = LossHelper.compile_leaf_info(
             f"{loss_prefix}{i}",
             BK.stack(one_loss_sums, 0).sum(),
             BK.input_real(one_loss_count),
             loss_lambda=loss_lambda)
         rets.append(loss_leaf)
     return rets
Example #3
File: mtl.py  Project: ValentinaPy/zmsp
 def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
     conf = self.conf
     # -----
     # special mode
     if conf.aug_word2 and conf.aug_word2_aug_encoder:
         _rop = RefreshOptions(training=False)  # special feature-mode!!
         self.embedder.refresh(_rop)
         self.encoder.refresh(_rop)
     # -----
     emb_t, mask_t = self.embedder(cur_input_map)
     rel_dist = cur_input_map.get("rel_dist", None)
     if rel_dist is not None:
         rel_dist = BK.input_idx(rel_dist)
     if conf.enc_choice == "vrec":
         enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
     elif conf.enc_choice == "original":  # todo(note): change back to arr for back compatibility
         assert rel_dist is None, "Original encoder does not support rel_dist"
         enc_t = self.encoder(emb_t, BK.get_value(mask_t))
         cache, enc_loss = None, None
     else:
         raise NotImplementedError()
     # another encoder based on attn
     final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
     if conf.aug_word2:
         emb2_t = self.aug_word2(insts)
         if conf.aug_word2_aug_encoder:
             # simply add them all together, detach orig-enc as features
             stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
             features = self.aug_mixturer(stack_hidden_t)
             aug_input = (emb2_t + conf.aug_detach_ratio*self.aug_detach_drop(features))
             final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t,
                                                             rel_dist=rel_dist, collect_loss=collect_loss)
         else:
             final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
     return emb_t, mask_t, final_enc_t, cache, enc_loss
Example #4
 def reg_scores_loss(self, *scores):
     if self.reg_scores_lambda > 0.:
         sreg_losses = [(z**2).mean() for z in scores]
         if len(sreg_losses) > 0:
             sreg_loss = BK.stack(sreg_losses).mean() * self.reg_scores_lambda
             return sreg_loss
     return None
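Note: the helper above is an L2-style penalty on raw scores. A minimal standalone sketch of the same reduction, assuming `BK.stack` is a thin wrapper around `torch.stack`; `reg_lambda` and the score shapes below are invented for illustration:

    import torch

    reg_lambda = 0.01  # stands in for self.reg_scores_lambda
    scores = [torch.randn(4, 7), torch.randn(4, 7, 3)]  # arbitrary score tensors
    # per-tensor mean square, averaged over tensors and scaled by the lambda
    sreg_loss = torch.stack([(z ** 2).mean() for z in scores]).mean() * reg_lambda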
Example #5
 def forward_features(self, ids_expr, mask_expr, typeids_expr,
                      other_embed_exprs: List):
     bmodel = self.model
     bmodel_embedding = bmodel.embeddings
     bmodel_encoder = bmodel.encoder
     # prepare
     attention_mask = mask_expr
     token_type_ids = BK.zeros(BK.get_shape(ids_expr)).long() if typeids_expr is None else typeids_expr
     extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
     # extended_attention_mask = extended_attention_mask.to(dtype=next(bmodel.parameters()).dtype)  # fp16 compatibility
     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
     # embeddings
     cur_layer = 0
     if self.trainable_min_layer <= 0:
         last_output = bmodel_embedding(ids_expr,
                                        position_ids=None,
                                        token_type_ids=token_type_ids)
     else:
         with BK.no_grad_env():
             last_output = bmodel_embedding(ids_expr,
                                            position_ids=None,
                                            token_type_ids=token_type_ids)
     # extra embeddings (this implies overall gradient requirements!!)
     for one_eidx, one_embed in enumerate(self.other_embeds):
         last_output += one_embed(other_embed_exprs[one_eidx])  # [bs, slen, D]
     # =====
     all_outputs = []
     if self.layer_is_output[cur_layer]:
         all_outputs.append(last_output)
     cur_layer += 1
     # todo(note): be careful about the indexes!
     # not-trainable encoders
     trainable_min_layer_idx = max(0, self.trainable_min_layer - 1)
     with BK.no_grad_env():
         for layer_module in bmodel_encoder.layer[:trainable_min_layer_idx]:
             last_output = layer_module(last_output,
                                        extended_attention_mask, None)[0]
             if self.layer_is_output[cur_layer]:
                 all_outputs.append(last_output)
             cur_layer += 1
     # trainable encoders
     for layer_module in bmodel_encoder.layer[trainable_min_layer_idx:self.output_max_layer]:
         last_output = layer_module(last_output, extended_attention_mask,
                                    None)[0]
         if self.layer_is_output[cur_layer]:
             all_outputs.append(last_output)
         cur_layer += 1
     assert cur_layer == self.output_max_layer + 1
     # stack
     if len(all_outputs) == 1:
         ret_expr = all_outputs[0].unsqueeze(-2)
     else:
         ret_expr = BK.stack(all_outputs, -2)  # [BS, SLEN, LAYER, D]
     final_ret_exp = self.output_f(ret_expr)
     return final_ret_exp
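Note: the `extended_attention_mask` lines follow the usual BERT additive-mask trick: a 0/1 padding mask is broadcast to [bs, 1, 1, slen] and turned into large negative biases added to the attention logits. A small self-contained illustration in plain `torch` (the tensor values are invented for the example):

    import torch

    mask = torch.tensor([[1., 1., 1., 0.]])        # [bs, slen], 1 = real token, 0 = padding
    ext = mask.unsqueeze(1).unsqueeze(2)           # [bs, 1, 1, slen], broadcastable over heads/queries
    ext = (1.0 - ext) * -10000.0                   # 0 for real tokens, -10000 for padding
    logits = torch.zeros(1, 2, 4, 4)               # dummy [bs, heads, qlen, klen] attention scores
    probs = torch.softmax(logits + ext, dim=-1)    # padded key positions get ~0 attention weight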
Example #6
 def _score(self, bert_expr, bidxes_t, hidxes_t):
     # ----
     # # debug
     # print(f"# ====\n Debug: {ArgSpanExpander._debug_count}")
     # ArgSpanExpander._debug_count += 1
     # ----
     bert_expr = bert_expr.view(BK.get_shape(bert_expr)[:-2] + [-1])  # flatten
     #
     max_range = self.conf.max_range
     max_slen = BK.get_shape(bert_expr, 1)
     # get candidates
     range_t = BK.arange_idx(max_range).unsqueeze(0)  # [1, R]
     bidxes_t = bidxes_t.unsqueeze(1)  # [N, 1]
     hidxes_t = hidxes_t.unsqueeze(1)  # [N, 1]
     left_cands = hidxes_t - range_t  # [N, R]
     right_cands = hidxes_t + range_t
     left_masks = (left_cands >= 0).float()
     right_masks = (right_cands < max_slen).float()
     left_cands.clamp_(min=0)
     right_cands.clamp_(max=max_slen - 1)
     # score
     head_exprs = bert_expr[bidxes_t, hidxes_t]  # [N, 1, D']
     left_cand_exprs = bert_expr[bidxes_t, left_cands]  # [N, R, D']
     right_cand_exprs = bert_expr[bidxes_t, right_cands]
     # actual scoring
     if self.use_lstm_scorer:
         batch_size = BK.get_shape(bidxes_t, 0)
         all_concat_outputs = []
         for cand_exprs, lstm_node in zip(
             [left_cand_exprs, right_cand_exprs], [self.llstm, self.rlstm]):
             cur_state = lstm_node.zero_init_hidden(batch_size)
             step_size = BK.get_shape(cand_exprs, 1)
             all_outputs = []
             for step_i in range(step_size):
                 cur_state = lstm_node(cand_exprs[:, step_i], cur_state,
                                       None)
                 all_outputs.append(cur_state[0])  # using h
             concat_output = BK.stack(all_outputs, 1)  # [N, R, ?]
             all_concat_outputs.append(concat_output)
         left_hidden, right_hidden = all_concat_outputs
         left_scores = self.lscorer(left_hidden).squeeze(-1)  # [N, R]
         right_scores = self.rscorer(right_hidden).squeeze(-1)  # [N, R]
     else:
         left_scores = self.lscorer([left_cand_exprs,
                                     head_exprs]).squeeze(-1)  # [N, R]
         right_scores = self.rscorer([right_cand_exprs,
                                      head_exprs]).squeeze(-1)
     # mask
     left_scores += Constants.REAL_PRAC_MIN * (1. - left_masks)
     right_scores += Constants.REAL_PRAC_MIN * (1. - right_masks)
     return left_scores, right_scores
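Note: out-of-range candidates are first clamped to valid indices (so the gather cannot fail) and then pushed to an effectively minus-infinite score via `Constants.REAL_PRAC_MIN`, so they can never win a later argmax/softmax. A minimal sketch of that pattern in plain `torch`; the `-1e30` constant merely stands in for `REAL_PRAC_MIN` and the sizes are invented:

    import torch

    max_range, max_slen = 4, 10
    hidxes = torch.tensor([[1], [8]])               # [N, 1] head positions
    range_t = torch.arange(max_range).unsqueeze(0)  # [1, R]
    left_cands = hidxes - range_t                   # [N, R] candidate left boundaries
    left_masks = (left_cands >= 0).float()          # 0 where the candidate falls off the left edge
    left_cands.clamp_(min=0)                        # keep indices valid for gathering
    left_scores = torch.randn(2, max_range)
    left_scores += -1e30 * (1. - left_masks)        # invalid candidates are effectively removed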
Example #7
File: base.py  Project: ValentinaPy/zmsp
 def collect_loss_and_backward(self, loss_info_cols: List[Dict],
                               training: bool, loss_factor: float):
     final_loss_dict = LossHelper.combine_multiple(loss_info_cols)  # loss_name -> {}
     if len(final_loss_dict) <= 0:
         return {}  # no loss!
     final_losses = []
     ret_info_vals = OrderedDict()
     for loss_name, loss_info in final_loss_dict.items():
         final_losses.append(loss_info['sum'] / (loss_info['count'] + 1e-5))
         for k in loss_info.keys():
             one_item = loss_info[k]
             ret_info_vals[f"loss:{loss_name}_{k}"] = one_item.item(
             ) if hasattr(one_item, "item") else float(one_item)
     final_loss = BK.stack(final_losses).sum()
     if training and final_loss.requires_grad:
         BK.backward(final_loss, loss_factor)
     return ret_info_vals
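Note: each accumulated loss is normalized by its count (with a small epsilon against empty batches) before a single backward call; `loss_factor` lets the caller scale the gradient, e.g. for gradient accumulation. A minimal sketch, assuming `BK.backward(loss, factor)` behaves roughly like `(loss * factor).backward()` (the tensors and factor below are invented):

    import torch

    w = torch.nn.Parameter(torch.zeros(3))
    loss_sum = (w * torch.tensor([1., 2., 3.])).sum()   # stands in for loss_info['sum']
    loss_count = torch.tensor(3.)                       # stands in for loss_info['count']
    final_loss = loss_sum / (loss_count + 1e-5)         # per-item average
    loss_factor = 0.5                                   # e.g. 1 / number_of_accumulation_steps
    (final_loss * loss_factor).backward()               # w.grad is now loss_factor * [1, 2, 3] / (3 + 1e-5)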
Example #8
File: mtl.py  Project: ValentinaPy/zmsp
 def __init__(self, conf: MtlMlmModelConf, vpack: VocabPackage):
     super().__init__(conf)
     # for easier checking
     self.word_vocab = vpack.get_voc("word")
     # components
     self.embedder = self.add_node("emb", EmbedderNode(self.pc, conf.emb_conf, vpack))
     self.inputter = Inputter(self.embedder, vpack)  # not a node
     self.emb_out_dim = self.embedder.get_output_dims()[0]
     self.enc_attn_count = conf.default_attn_count
     if conf.enc_choice == "vrec":
         self.encoder = self.add_component("enc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf))
         self.enc_attn_count = self.encoder.attn_count
     elif conf.enc_choice == "original":
         conf.oenc_conf._input_dim = self.emb_out_dim
         self.encoder = self.add_node("enc", MyEncoder(self.pc, conf.oenc_conf))
     else:
         raise NotImplementedError()
     zlog(f"Finished building model's encoder {self.encoder}, all size is {self.encoder.count_allsize_parameters()}")
     self.enc_out_dim = self.encoder.get_output_dims()[0]
     # --
     conf.rprep_conf._rprep_vr_conf.matt_conf.head_count = self.enc_attn_count  # make head-count agree
     self.rpreper = self.add_node("rprep", RPrepNode(self.pc, self.enc_out_dim, conf.rprep_conf))
     # --
     self.lambda_agree = self.add_scheduled_value(ScheduledValue(f"agr:lambda", conf.lambda_agree))
     self.agree_loss_f = EntropyHelper.get_method(conf.agree_loss_f)
     # --
     self.masklm = self.add_component("mlm", MaskLMNode(self.pc, self.enc_out_dim, conf.mlm_conf, self.inputter))
     self.plainlm = self.add_component("plm", PlainLMNode(self.pc, self.enc_out_dim, conf.plm_conf, self.inputter))
     # todo(note): here we use attn as dim_pair, do not use pair if not using vrec!!
     self.orderpr = self.add_component("orp", OrderPredNode(
         self.pc, self.enc_out_dim, self.enc_attn_count, conf.orp_conf, self.inputter))
     # =====
     # pre-training pre-load point!!
     if conf.load_pretrain_model_name:
         zlog(f"At preload_pretrain point: Loading from {conf.load_pretrain_model_name}")
         self.pc.load(conf.load_pretrain_model_name, strict=False)
     # =====
     self.dpar = self.add_component("dpar", DparG1Decoder(
         self.pc, self.enc_out_dim, self.enc_attn_count, conf.dpar_conf, self.inputter))
     self.upos = self.add_component("upos", SeqLabNode(
         self.pc, "pos", self.enc_out_dim, self.conf.upos_conf, self.inputter))
     if conf.do_ner:
         if conf.ner_use_crf:
             self.ner = self.add_component("ner", SeqCrfNode(
                 self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter))
         else:
             self.ner = self.add_component("ner", SeqLabNode(
                 self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter))
     else:
         self.ner = None
     # for pairwise reprs (no trainable params here!)
     self.rel_dist_embed = self.add_node("oremb", PosiEmbedding2(self.pc, n_dim=self.enc_attn_count, max_val=100))
     self._prepr_f_attn_sum = lambda cache, rdist: BK.stack(cache.list_attn, 0).sum(0) if (len(cache.list_attn))>0 else None
     self._prepr_f_attn_avg = lambda cache, rdist: BK.stack(cache.list_attn, 0).mean(0) if (len(cache.list_attn))>0 else None
     self._prepr_f_attn_max = lambda cache, rdist: BK.stack(cache.list_attn, 0).max(0)[0] if (len(cache.list_attn))>0 else None
     self._prepr_f_attn_last = lambda cache, rdist: cache.list_attn[-1] if (len(cache.list_attn))>0 else None
     self._prepr_f_rdist = lambda cache, rdist: self._get_rel_dist_embed(rdist, False)
     self._prepr_f_rdist_abs = lambda cache, rdist: self._get_rel_dist_embed(rdist, True)
     self.prepr_f = getattr(self, "_prepr_f_"+conf.prepr_choice)  # shortcut
     # --
     self.testing_rand_gen = Random.create_sep_generator(conf.testing_rand_gen_seed)  # special gen for testing
     # =====
     if conf.orp_loss_special:
         self.orderpr.add_node_special(self.masklm)
     # =====
     # extra one!!
     self.aug_word2 = self.aug_encoder = self.aug_mixturer = None
     if conf.aug_word2:
         self.aug_word2 = self.add_node("aug2", AugWord2Node(self.pc, conf.emb_conf, vpack,
                                                             "word2", conf.aug_word2_dim, self.emb_out_dim))
         if conf.aug_word2_aug_encoder:
             assert conf.enc_choice == "vrec"
             self.aug_detach_drop = self.add_node("dd", Dropout(self.pc, (self.enc_out_dim,), fix_rate=conf.aug_detach_dropout))
             self.aug_encoder = self.add_component("Aenc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf))
             self.aug_mixturer = self.add_node("Amix", BertFeaturesWeightLayer(self.pc, conf.aug_detach_numlayer))
Example #9
 def __init__(self,
              pc,
              conf: HLabelNodeConf,
              hl_vocab: HLabelVocab,
              eff_max_layer=None):
     super().__init__(pc, None, None)
     self.conf = conf
     self.hl_vocab = hl_vocab
     assert self.hl_vocab.nil_as_zero  # for each layer, the idx=0 is the full-NIL
     # basic pool embeddings
     npvec = hl_vocab.pool_init_vec
     if not conf.pool_init_hint:
         npvec = None
     else:
         assert npvec is not None, "pool-init not provided by the Vocab!"
     n_dim, n_pool = conf.n_dim, len(hl_vocab.pools_k)
     self.pool_pred = self.add_sub_node(
         "pp",
         Embedding(
             pc,
             n_pool,
             n_dim,
             fix_row0=conf.zero_nil,
             npvec=npvec,
             init_rop=(NoDropRop() if conf.nodrop_pred_embeds else None)))
     if conf.tie_embeds:
         self.pool_lookup = self.pool_pred
     else:
         self.pool_lookup = self.add_sub_node(
             "pl",
             Embedding(pc,
                       n_pool,
                       n_dim,
                       fix_row0=conf.zero_nil,
                       npvec=npvec,
                       init_rop=(NoDropRop()
                                 if conf.nodrop_lookup_embeds else None)))
     # layered labels embeddings (to be refreshed)
     self.max_layer = hl_vocab.max_layer
     self.layered_embeds_pred = [None] * self.max_layer
     self.layered_embeds_lookup = [None] * self.max_layer
     self.layered_prei = [None] * self.max_layer  # previous layer i, for score combining
     self.layered_isnil = [None] * self.max_layer  # whether is nil(None)
     self.zero_nil = conf.zero_nil
     # lookup summer
     assert conf.strategy_predict == "sum"
     self.lookup_is_sum, self.lookup_is_ff = [conf.strategy_lookup == z for z in ["sum", "ff"]]
     if self.lookup_is_ff:
         self.lookup_summer = self.add_sub_node(
             "summer", Affine(pc, [n_dim] * self.max_layer, n_dim, act="tanh"))
     elif self.lookup_is_sum:
         self.sum_dropout = self.add_sub_node("sdrop", Dropout(pc, (n_dim,)))
         self.lookup_summer = lambda embeds: self.sum_dropout(BK.stack(embeds, 0).sum(0))
     else:
         raise NotImplementedError(f"UNK strategy_lookup: {conf.strategy_lookup}")
     # bias for prediction
     self.prediction_sizes = [len(hl_vocab.layered_pool_links_padded[i]) for i in range(self.max_layer)]
     if conf.bias_predict:
         self.biases_pred = [self.add_param(name="B", shape=(x,)) for x in self.prediction_sizes]
     else:
         self.biases_pred = [None] * self.max_layer
     # =====
     # training
     self.is_hinge_loss, self.is_prob_loss = [conf.loss_function == z for z in ["hinge", "prob"]]
     self.loss_lambdas = conf.loss_lambdas + [1.] * (
         self.max_layer - len(conf.loss_lambdas))  # loss scale
     self.margin_lambdas = conf.margin_lambdas + [0.] * (
         self.max_layer - len(conf.margin_lambdas))  # margin scale
     self.lookup_soft_alphas = conf.lookup_soft_alphas + [1.] * (
         self.max_layer - len(conf.lookup_soft_alphas))
     self.loss_fullnil_weight = conf.loss_fullnil_weight
     # ======
     # set current effective max_layer
     self.eff_max_layer = self.max_layer
     if eff_max_layer is not None:
         self.set_eff_max_layer(eff_max_layer)
Example #10
 def loss(self, input_expr, loss_mask, gold_idxes, margin=0.):
     gold_all_idxes = self._get_all_idxes(gold_idxes)
     # scoring
     raw_scores = self._raw_scores(input_expr)
     raw_scores_aug = []
     margin_P, margin_R, margin_T = self.conf.margin_lambda_P, self.conf.margin_lambda_R, self.conf.margin_lambda_T
     #
     gold_shape = BK.get_shape(gold_idxes)  # [*]
     gold_bsize_prod = np.prod(gold_shape)
     # gold_arange_idxes = BK.arange_idx(gold_bsize_prod)
     # margin
     for i in range(self.eff_max_layer):
         cur_gold_inputs = gold_all_idxes[i]
         # add margin
         cur_scores = raw_scores[i]  # [*, ?]
         cur_margin = margin * self.margin_lambdas[i]
         if cur_margin > 0.:
             cur_num_target = self.prediction_sizes[i]
             cur_isnil = self.layered_isnil[i].byte()  # [NLab]
             cost_matrix = BK.constants([cur_num_target, cur_num_target],
                                        margin_T)  # [gold, pred]
             cost_matrix[cur_isnil, :] = margin_P
             cost_matrix[:, cur_isnil] = margin_R
             diag_idxes = BK.arange_idx(cur_num_target)
             cost_matrix[diag_idxes, diag_idxes] = 0.
             margin_mat = cost_matrix[cur_gold_inputs]
             cur_aug_scores = cur_scores + margin_mat  # [*, ?]
         else:
             cur_aug_scores = cur_scores
         raw_scores_aug.append(cur_aug_scores)
     # cascade scores
     final_scores = self._cascade_scores(raw_scores_aug)
     # loss weight, todo(note): asserted self.hl_vocab.nil_as_zero before
     loss_weights = ((gold_idxes == 0).float() *
                     (self.loss_fullnil_weight - 1.) +
                     1.) if self.loss_fullnil_weight < 1. else 1.
     # calculate loss
     loss_prob_entropy_lambda = self.conf.loss_prob_entropy_lambda
     loss_prob_reweight = self.conf.loss_prob_reweight
     final_losses = []
     no_loss_max_gold = self.conf.no_loss_max_gold
     if loss_mask is None:
         loss_mask = BK.constants(BK.get_shape(input_expr)[:-1], 1.)
     for i in range(self.eff_max_layer):
         cur_final_scores, cur_gold_inputs = final_scores[i], gold_all_idxes[i]  # [*, ?], [*]
         # collect the loss
         if self.is_hinge_loss:
             cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
             cur_gold_scores = BK.gather(cur_final_scores,
                                         cur_gold_inputs.unsqueeze(-1),
                                         -1).squeeze(-1)
             cur_loss = cur_pred_scores - cur_gold_scores  # [*], todo(note): this must be >=0
             if no_loss_max_gold:  # this should be implicit
                 cur_loss = cur_loss * (cur_loss > 0.).float()
         elif self.is_prob_loss:
             # cur_loss = BK.loss_nll(cur_final_scores, cur_gold_inputs)  # [*]
             cur_loss = self._my_loss_prob(cur_final_scores,
                                           cur_gold_inputs,
                                           loss_prob_entropy_lambda,
                                           loss_mask,
                                           loss_prob_reweight)  # [*]
             if no_loss_max_gold:
                 cur_pred_scores, cur_pred_idxes = cur_final_scores.max(-1)
                 cur_gold_scores = BK.gather(cur_final_scores,
                                             cur_gold_inputs.unsqueeze(-1),
                                             -1).squeeze(-1)
                 cur_loss = cur_loss * (cur_gold_scores >
                                        cur_pred_scores).float()
         else:
             raise NotImplementedError(
                 f"UNK loss {self.conf.loss_function}")
         # here first summing up, divided at the outside
         one_loss_sum = (cur_loss * (loss_mask * loss_weights)).sum() * self.loss_lambdas[i]
         final_losses.append(one_loss_sum)
     # final sum
     final_loss_sum = BK.stack(final_losses).sum()
     _, ret_lab_idxes, ret_lab_embeds = self._predict(final_scores, None)
     return [[final_loss_sum, loss_mask.sum()]], ret_lab_idxes, ret_lab_embeds
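Note: the margin branch builds a [gold, pred] cost matrix: `margin_T` everywhere, `margin_P` on rows whose gold label is NIL, `margin_R` on columns whose predicted label is NIL, and 0 on the diagonal; indexing it with the gold ids yields a per-instance margin added to the scores (cost-augmented scoring for the hinge loss). A small sketch of that construction in plain `torch`, with invented sizes and margin values:

    import torch

    n_lab, margin_T, margin_P, margin_R = 4, 1.0, 0.5, 2.0
    isnil = torch.tensor([True, False, False, False])   # label 0 is the full-NIL label
    cost = torch.full((n_lab, n_lab), margin_T)          # [gold, pred] default cost
    cost[isnil, :] = margin_P                            # gold is NIL but something else is predicted
    cost[:, isnil] = margin_R                            # gold is a real label but NIL is predicted
    diag = torch.arange(n_lab)
    cost[diag, diag] = 0.                                 # correct prediction costs nothing
    gold = torch.tensor([2, 0])                           # gold ids for two instances
    scores = torch.randn(2, n_lab)
    aug_scores = scores + cost[gold]                      # cost-augmented scores, as in the loop above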
Example #11
 def run(self, insts: List[DocInstance], training: bool):
     conf = self.conf
     BERT_MAX_LEN = 510  # save 2 for CLS and SEP
     # =====
     # encoder 1: the basic encoder
     # todo(note): only DocInstance input works for this mode, otherwise it will break
     if conf.m2e_use_basic:
         reidx_pad_len = conf.ms_extend_budget
         # enc the basic part + also get some indexes
         sentid2offset = {}  # id(sent)->overall_seq_offset
         seq_offset = 0  # if look at the docs in one seq
         all_sents = []  # (inst, d_idx, s_idx)
         for d_idx, one_doc in enumerate(insts):
             assert isinstance(one_doc, DocInstance)
             for s_idx, one_sent in enumerate(one_doc.sents):
                 # todo(note): here we encode all the sentences
                 all_sents.append((one_sent, d_idx, s_idx))
                 sentid2offset[id(one_sent)] = seq_offset
                 seq_offset += one_sent.length - 1  # exclude extra ROOT node
         sent_reprs = self.run_sents(all_sents, insts, training)
         # flatten and concatenate and re-index
         reidxes_arr = np.zeros(seq_offset + reidx_pad_len, dtype=np.long)  # todo(note): extra padding to avoid out of boundary
         all_flattened_reprs = []
         all_flatten_offset = 0  # the local offset for batched basic encoding
         for one_pack in sent_reprs:
             one_sents, _, one_repr_ef, one_repr_evt, _ = one_pack
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             one_repr_t = one_repr_evt
             _, one_slen, one_ldim = BK.get_shape(one_repr_t)
             all_flattened_reprs.append(one_repr_t.view([-1, one_ldim]))
             # fill in the indexes
             for one_sent in one_sents:
                 cur_start_offset = sentid2offset[id(one_sent)]
                 cur_real_slen = one_sent.length - 1
                 # again, +1 to get rid of extra ROOT
                 reidxes_arr[cur_start_offset:cur_start_offset+cur_real_slen] = \
                     np.arange(cur_real_slen, dtype=np.long) + (all_flatten_offset+1)
                 all_flatten_offset += one_slen  # here add the slen in batched version
         # re-idxing
         seq_sent_repr0 = BK.concat(all_flattened_reprs, 0)
         seq_sent_repr = BK.select(seq_sent_repr0, reidxes_arr, 0)  # [all_seq_len, D]
     else:
         sentid2offset = defaultdict(int)
         seq_sent_repr = None
     # =====
     # repack and prepare for multiple sent enc
     # todo(note): here, the criterion is based on bert's tokenizer
     all_ms_info = []
     if isinstance(insts[0], DocInstance):
         for d_idx, one_doc in enumerate(insts):
             for s_idx, x in enumerate(one_doc.sents):
                 # the basic criterion is the same as the basic one
                 include_flag = False
                 if training:
                     if x.length<self.train_skip_length and x.length>=self.train_min_length \
                             and (len(x.events)>0 or next(self.random_sample_stream)>self.train_skip_noevt_rate):
                         include_flag = True
                 else:
                     if x.length >= self.test_min_length:
                         include_flag = True
                 if include_flag:
                     all_ms_info.append(x.preps["ms"])  # use the pre-calculated one
     else:
         # multisent based
         all_ms_info = insts.copy()  # shallow copy
     # =====
     # encoder 2: the bert one (multi-sent encoding)
     ms_size_f = lambda x: x.subword_size
     all_ms_info.sort(key=ms_size_f)
     all_ms_buckets = self._bucket_sents_by_length(
         all_ms_info,
         conf.benc_bucket_range,
         ms_size_f,
         max_bsize=conf.benc_bucket_msize)
     berter = self.berter
     rets = []
     bert_use_center_typeids = conf.bert_use_center_typeids
     bert_use_special_typeids = conf.bert_use_special_typeids
     bert_other_inputs = conf.bert_other_inputs
     for one_bucket in all_ms_buckets:
         # prepare
         batched_ids = []
         batched_starts = []
         batched_seq_offset = []
         batched_typeids = []
         batched_other_inputs_list: List = [[] for _ in bert_other_inputs]  # List(comp) of List(batch) of List(idx)
         for one_item in one_bucket:
             one_sents = one_item.sents
             one_center_sid = one_item.center_idx
             one_ids, one_starts, one_typeids = [], [], []
             one_other_inputs_list = [[] for _ in bert_other_inputs]  # List(comp) of List(idx)
             for one_sid, one_sent in enumerate(one_sents):  # for bert
                 one_bidxes = one_sent.preps["bidx"]
                 one_ids.extend(one_bidxes.subword_ids)
                 one_starts.extend(one_bidxes.subword_is_start)
                 # prepare other inputs
                 for this_field_name, this_tofill_list in zip(bert_other_inputs, one_other_inputs_list):
                     this_tofill_list.extend(one_sent.preps["sub_" + this_field_name])
                 # todo(note): special procedure
                 if bert_use_center_typeids:
                     if one_sid != one_center_sid:
                         one_typeids.extend([0] * len(one_bidxes.subword_ids))
                     else:
                         this_typeids = [1] * len(one_bidxes.subword_ids)
                         if bert_use_special_typeids:
                             # todo(note): this is the special mode that we are given the events!!
                             for this_event in one_sents[one_center_sid].events:
                                 _, this_wid, this_wlen = this_event.mention.hard_span.position(headed=False)
                                 for a, b in one_item.center_word2sub[this_wid - 1:this_wid - 1 + this_wlen]:
                                     this_typeids[a:b] = [0] * (b - a)
                         one_typeids.extend(this_typeids)
             batched_ids.append(one_ids)
             batched_starts.append(one_starts)
             batched_typeids.append(one_typeids)
             for comp_one_oi, comp_batched_oi in zip(
                     one_other_inputs_list, batched_other_inputs_list):
                 comp_batched_oi.append(comp_one_oi)
             # for basic part
             batched_seq_offset.append(sentid2offset[id(one_sents[0])])
         # bert forward: [bs, slen, fold, D]
         if not bert_use_center_typeids:
             batched_typeids = None
         bert_expr0, mask_expr = berter.forward_batch(
             batched_ids,
             batched_starts,
             batched_typeids,
             training=training,
             other_inputs=batched_other_inputs_list)
         if self.m3_enc_is_empty:
             bert_expr = bert_expr0
         else:
             mask_arr = BK.get_value(mask_expr)  # [bs, slen]
             m3e_exprs = [
                 cur_enc(bert_expr0[:, :, cur_i], mask_arr)
                 for cur_i, cur_enc in enumerate(self.m3_encs)
             ]
             bert_expr = BK.stack(m3e_exprs, -2)  # on the fold dim again
         # collect basic ones: [bs, slen, D'] or None
         if seq_sent_repr is not None:
             arange_idxes_t = BK.arange_idx(BK.get_shape(mask_expr, -1)).unsqueeze(0)  # [1, slen]
             offset_idxes_t = BK.input_idx(batched_seq_offset).unsqueeze(-1) + arange_idxes_t  # [bs, slen]
             basic_expr = seq_sent_repr[offset_idxes_t]  # [bs, slen, D']
         elif conf.m2e_use_basic_dep:
             # collect each token's head-bert and ud-label, then forward with adp
             fake_sents = [one_item.fake_sent for one_item in one_bucket]
             # head idx and labels, no artificial ROOT
             padded_head_arr, _ = self.dep_padder.pad(
                 [s.ud_heads.vals[1:] for s in fake_sents])
             padded_label_arr, _ = self.dep_padder.pad(
                 [s.ud_labels.idxes[1:] for s in fake_sents])
             # get tensor
             padded_head_t = BK.input_idx(padded_head_arr) - 1  # here, the idx excludes the root
             padded_head_t.clamp_(min=0)  # [bs, slen]
             padded_label_t = BK.input_idx(padded_label_arr)
             # get inputs
             input_head_bert_t = bert_expr[
                 BK.arange_idx(len(fake_sents)).unsqueeze(-1),
                 padded_head_t]  # [bs, slen, fold, D]
             input_label_emb_t = self.dep_label_emb(
                 padded_label_t)  # [bs, slen, D']
             basic_expr = self.dep_layer(
                 input_head_bert_t, None,
                 [input_label_emb_t])  # [bs, slen, ?]
         elif conf.m2e_use_basic_plus:
             sent_reprs = self.run_sents([(one_item.fake_sent, None, None) for one_item in one_bucket],
                                         insts, training, use_one_bucket=True)
             assert len(sent_reprs) == 1, "Unsupported split reprs for basic encoder, please set enc_bucket_range<=benc_bucket_range"
             _, _, one_repr_ef, one_repr_evt, _ = sent_reprs[0]
             assert one_repr_ef is one_repr_evt, "Currently does not support separate basic enc in m3 mode"
             basic_expr = one_repr_evt[:, 1:]  # exclude ROOT, [bs, slen, D]
             assert BK.get_shape(basic_expr)[:2] == BK.get_shape(bert_expr)[:2]
         else:
             basic_expr = None
         # pack: (List[ms_item], bert_expr, basic_expr)
         rets.append((one_bucket, bert_expr, basic_expr))
     return rets
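Note: the basic-encoder branch flattens every per-bucket representation into one long [total, D] matrix, records where each sentence starts in that flat view, and later gathers per-document windows out of it with a single index select. A minimal sketch of that flatten-then-reindex step, assuming `BK.concat`/`BK.select` behave like `torch.cat`/`torch.index_select` (shapes and indexes below are invented):

    import torch

    D = 8
    packs = [torch.randn(2, 5, D), torch.randn(3, 4, D)]   # two batched packs [bs, slen, D]
    flat = torch.cat([p.view(-1, D) for p in packs], 0)    # [2*5 + 3*4, D] flattened rows
    # indexes of the "real" tokens inside the flat view (offsets recorded while flattening)
    reidxes = torch.tensor([1, 2, 3, 4, 11, 12, 13])
    seq_sent_repr = torch.index_select(flat, 0, reidxes)   # [all_seq_len, D]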
Example #12
 def __init__(self, pc: BK.ParamCollection, input_dim: int,
              conf: MaskLMNodeConf, inputter: Inputter):
     super().__init__(pc, conf, name="MLM")
     self.conf = conf
     self.inputter = inputter
     self.input_dim = input_dim
     # this step is performed at the embedder, thus still does not influence the inputter
     self.add_root_token = self.inputter.embedder.add_root_token
     # vocab and padder
     vpack = inputter.vpack
     vocab_word, vocab_pos = vpack.get_voc("word"), vpack.get_voc("pos")
     # no mask fields
     self.nomask_names_set = set(conf.nomask_names)
     # models
     if conf.hid_dim <= 0:  # no hidden layer
         self.hid_layer = None
         self.pred_input_dim = input_dim
     else:
         self.hid_layer = self.add_sub_node(
             "hid", Affine(pc, input_dim, conf.hid_dim, act=conf.hid_act))
         self.pred_input_dim = conf.hid_dim
     # todo(note): unk is the first one above real words
     self.pred_word_size = min(conf.max_pred_rank + 1, vocab_word.unk)
     self.pred_pos_size = vocab_pos.unk
     if conf.tie_input_embeddings:
         zwarn("Tie all preds in mlm with input embeddings!!")
         self.pred_word_layer = self.pred_pos_layer = None
         self.inputter_word_node = self.inputter.embedder.get_node("word")
         self.inputter_pos_node = self.inputter.embedder.get_node("pos")
     else:
         self.inputter_word_node, self.inputter_pos_node = None, None
         self.pred_word_layer = self.add_sub_node(
             "pw", Affine(pc, self.pred_input_dim, self.pred_word_size, init_rop=NoDropRop()))
         self.pred_pos_layer = self.add_sub_node(
             "pp", Affine(pc, self.pred_input_dim, self.pred_pos_size, init_rop=NoDropRop()))
         if conf.init_pred_from_pretrain:
             npvec = vpack.get_emb("word")
             if npvec is None:
                 zwarn("Pretrained vector not provided, skip init pred embeddings!!")
             else:
                 with BK.no_grad_env():
                     self.pred_word_layer.ws[0].copy_(BK.input_real(npvec[:self.pred_word_size].T))
                 zlog(f"Init pred embeddings from pretrained vectors (size={self.pred_word_size}).")
     # =====
     COMBINE_METHOD_FS = {
         "sum": lambda xs: BK.stack(xs, -1).sum(-1),
         "avg": lambda xs: BK.stack(xs, -1).mean(-1),
         "min": lambda xs: BK.stack(xs, -1).min(-1)[0],
         "max": lambda xs: BK.stack(xs, -1).max(-1)[0],
     }
     self.loss_comb_f = COMBINE_METHOD_FS[conf.loss_comb_method]
     self.score_comb_f = COMBINE_METHOD_FS[conf.score_comb_method]
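Note: the `COMBINE_METHOD_FS` table maps a config string to a reduction over a list of same-shaped tensors (stack on a new last axis, then reduce it). A quick usage sketch, again assuming `BK.stack` mirrors `torch.stack`; the loss tensors below are invented:

    import torch

    COMBINE_METHOD_FS = {
        "sum": lambda xs: torch.stack(xs, -1).sum(-1),
        "avg": lambda xs: torch.stack(xs, -1).mean(-1),
        "min": lambda xs: torch.stack(xs, -1).min(-1)[0],
        "max": lambda xs: torch.stack(xs, -1).max(-1)[0],
    }

    word_loss = torch.tensor([0.7, 1.2])
    pos_loss = torch.tensor([0.3, 0.4])
    combined = COMBINE_METHOD_FS["avg"]([word_loss, pos_loss])  # tensor([0.5000, 0.8000])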