Example #1
 def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float], label_weight_list: List[float],
           bidxes_list: List[int]):
     # 1. collect (batched) features; todo(note): use prev state for scoring
     hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
     # 2. get new sreprs
     scorer = self.scorer
     s_enc = self.slayer
     bsize_range_t = BK.input_idx(bidxes_list)
     node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
     node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
     # label loss
     if self.system_labeled:
         node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
         _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
         label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
         label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
         final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
     else:
         label_scores = final_label_loss_sum = BK.zeros([])
     # arc loss
     node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
     _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
     arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
     final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
     # score reg
     return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
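Note: BK.gather_one_lastdim is not shown here; from its usage above it appears to pick one entry per row along the last dimension (the score of each action's gold label). A minimal plain-PyTorch sketch of the same operation, assuming BK tensors are torch tensors:

import torch

label_scores_full = torch.randn(6, 40)          # [*, Lab]: one row per action
gold_labels = torch.randint(0, 40, (6,))        # one gold label id per action
picked = label_scores_full.gather(-1, gold_labels.unsqueeze(-1))  # [*, 1]
label_scores = picked.squeeze(-1)               # [*], matching the .squeeze(-1) above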
Example #2
 def loss(self, ms_items: List, bert_expr):
     conf = self.conf
     max_range = self.conf.max_range
     bsize = len(ms_items)
     # collect instances
     col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(
         ms_items, True)
     if len(col_efs) == 0:
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
     left_scores, right_scores = self._score(bert_expr, col_bidxes_t,
                                             col_hidxes_t)  # [N, R]
     if conf.use_binary_scorer:
         left_binaries, right_binaries = (BK.arange_idx(max_range)<=col_ldists_t.unsqueeze(-1)).float(), \
                                         (BK.arange_idx(max_range)<=col_rdists_t.unsqueeze(-1)).float()  # [N,R]
         left_losses = BK.binary_cross_entropy_with_logits(
             left_scores, left_binaries, reduction='none')[:, 1:]
         right_losses = BK.binary_cross_entropy_with_logits(
             right_scores, right_binaries, reduction='none')[:, 1:]
         left_count = right_count = BK.input_real(
             BK.get_shape(left_losses, 0) * (max_range - 1))
     else:
         left_losses = BK.loss_nll(left_scores, col_ldists_t)
         right_losses = BK.loss_nll(right_scores, col_rdists_t)
         left_count = right_count = BK.input_real(
             BK.get_shape(left_losses, 0))
     return [[left_losses.sum(), left_count, left_count],
             [right_losses.sum(), right_count, right_count]]
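Note: the use_binary_scorer branch turns a scalar gold distance into a monotone 0/1 vector over ranges: position r is 1 iff r <= dist. A small plain-PyTorch sketch of that comparison (shapes follow the comments above):

import torch

max_range = 5
ldists = torch.tensor([0, 2, 4])                                        # [N] gold left distances
left_binaries = (torch.arange(max_range) <= ldists.unsqueeze(-1)).float()
# -> [[1., 0., 0., 0., 0.],
#     [1., 1., 1., 0., 0.],
#     [1., 1., 1., 1., 1.]]   # [N, R]
# column 0 is always 1, which is why the loss above drops it with [:, 1:]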
Example #3
 def batch_inputs_h(self, insts: List[Sentence]):
     key, items_getter = self.extract_type, self.items_getter
     nil_idx = 0
     # get gold/input data and batch
     all_masks, all_idxes, all_items, all_valid = [], [], [], []
     all_idxes2, all_items2 = [], []  # secondary types
     for sent in insts:
         preps = sent.preps.get(key)
         # not cached, rebuild them
         if preps is None:
             length = sent.length
             items = items_getter(sent)
             # token-idx -> ...
              prep_masks = [0.] * length
              prep_idxes = [nil_idx] * length
              prep_items = [None] * length
             prep_idxes2, prep_items2 = [nil_idx] * length, [None] * length
             if items is None:
                 # todo(note): there are samples that do not have entity annotations (KBP15)
                 #  final 0/1 indicates valid or not
                 prep_valid = 0.
             else:
                 prep_valid = 1.
                 for one_item in items:
                     this_hwidx = one_item.mention.hard_span.head_wid
                     this_hlidx = one_item.type_idx
                     # todo(+N): ignore except the first two types (already ranked by type-freq)
                     if prep_idxes[this_hwidx] == 0:
                         prep_masks[this_hwidx] = 1.
                         prep_idxes[this_hwidx] = self.hlidx2idx(
                             this_hlidx)  # change to int here!
                         prep_items[this_hwidx] = one_item
                     elif prep_idxes2[this_hwidx] == 0:
                         prep_idxes2[this_hwidx] = self.hlidx2idx(
                             this_hlidx)  # change to int here!
                         prep_items2[this_hwidx] = one_item
             sent.preps[key] = (prep_masks, prep_idxes, prep_items,
                                prep_valid, prep_idxes2, prep_items2)
         else:
             prep_masks, prep_idxes, prep_items, prep_valid, prep_idxes2, prep_items2 = preps
         # =====
         all_masks.append(prep_masks)
         all_idxes.append(prep_idxes)
         all_items.append(prep_items)
         all_valid.append(prep_valid)
         all_idxes2.append(prep_idxes2)
         all_items2.append(prep_items2)
     # pad and batch
     mention_masks = BK.input_real(
         self.padder_mask.pad(all_masks)[0])  # [*, slen]
     mention_idxes = BK.input_idx(
         self.padder_idxes.pad(all_idxes)[0])  # [*, slen]
     mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, slen]
     mention_valid = BK.input_real(all_valid)  # [*]
     mention_idxes2 = BK.input_idx(
         self.padder_idxes.pad(all_idxes2)[0])  # [*, slen]
     mention_items2_arr, _ = self.padder_items.pad(all_items2)  # [*, slen]
     return mention_masks, mention_idxes, mention_items_arr, mention_valid, mention_idxes2, mention_items2_arr
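Note: the padder calls above appear to pad ragged per-sentence lists to a rectangular [batch, max-len] array plus a mask; the exact DataPadder API is assumed from its usage here. A loop-based numpy sketch of that behavior:

import numpy as np

def pad_2d(rows, pad_val=0.):
    # rows: ragged list of lists -> (arr [bs, max_len], mask [bs, max_len])
    max_len = max(len(r) for r in rows)
    arr = np.full((len(rows), max_len), pad_val, dtype=np.float32)
    mask = np.zeros((len(rows), max_len), dtype=np.float32)
    for i, r in enumerate(rows):
        arr[i, :len(r)] = r
        mask[i, :len(r)] = 1.
    return arr, mask

arr, mask = pad_2d([[1., 1.], [1., 1., 1.]])   # both arrays have shape (2, 3)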
Example #4
 def batch_inputs_g0(self, insts: List[Sentence]):
     # similar to "batch_inputs_h", but further extend for each label
     key, items_getter = self.extract_type, self.items_getter
     # nil_idx = 0
     # get gold/input data and batch
     output_size = self.hl_output_size
     all_masks, all_items, all_valid = [], [], []
     for sent in insts:
         preps = sent.preps.get(key)
         # not cached, rebuild them
         if preps is None:
             length = sent.length
             items = items_getter(sent)
             # token-idx -> [slen, out-size]
             prep_masks = [[0. for _i1 in range(output_size)]
                           for _i0 in range(length)]
             prep_items = [[None for _i1 in range(output_size)]
                           for _i0 in range(length)]
             if items is None:
                 # todo(note): there are samples that do not have entity annotations (KBP15)
                 #  final 0/1 indicates valid or not
                 prep_valid = 0.
             else:
                 prep_valid = 1.
                 for one_item in items:
                     this_hwidx = one_item.mention.hard_span.head_wid
                     this_hlidx = one_item.type_idx
                     this_tidx = self.hlidx2idx(
                         this_hlidx)  # change to int here!
                     # todo(+N): simply ignore repeated ones with same type and trigger
                     if prep_masks[this_hwidx][this_tidx] == 0.:
                         prep_masks[this_hwidx][this_tidx] = 1.
                         prep_items[this_hwidx][this_tidx] = one_item
             sent.preps[key] = (prep_masks, prep_items, prep_valid)
         else:
             prep_masks, prep_items, prep_valid = preps
         # =====
         all_masks.append(prep_masks)
         all_items.append(prep_items)
         all_valid.append(prep_valid)
     # pad and batch
     mention_masks = BK.input_real(
         self.padder_mask_3d.pad(all_masks)[0])  # [*, slen, L]
     mention_idxes = None
     mention_items_arr, _ = self.padder_items_3d.pad(
         all_items)  # [*, slen, L]
     mention_valid = BK.input_real(all_valid)  # [*]
     return mention_masks, mention_idxes, mention_items_arr, mention_valid
Example #5
def main():
    pc = BK.ParamCollection()
    N_BATCH, N_SEQ = 8, 4
    N_HIDDEN, N_LAYER = 5, 3
    N_INPUT = N_HIDDEN
    N_FF = 10
    # encoders
    rnn_encoder = layers.RnnLayerBatchFirstWrapper(pc, layers.RnnLayer(pc, N_INPUT, N_HIDDEN, N_LAYER, bidirection=True))
    cnn_encoder = layers.Sequential(pc, [layers.CnnLayer(pc, N_INPUT, N_HIDDEN, 3, act="relu") for _ in range(N_LAYER)])
    att_encoder = layers.Sequential(pc, [layers.TransformerEncoderLayer(pc, N_INPUT, N_FF) for _ in range(N_LAYER)])
    dropout_md = layers.DropoutLastN(pc)
    #
    rop = layers.RefreshOptions(hdrop=0.2, gdrop=0.2, dropmd=0.2, fix_drop=True)
    rnn_encoder.refresh(rop)
    cnn_encoder.refresh(rop)
    att_encoder.refresh(rop)
    dropout_md.refresh(rop)
    #
    x = BK.input_real(np.random.randn(N_BATCH, N_SEQ, N_INPUT))
    x_mask = np.asarray([[1.]*z+[0.]*(N_SEQ-z) for z in np.random.randint(N_SEQ//2, N_SEQ, N_BATCH)])
    y_rnn = rnn_encoder(x, x_mask)
    y_cnn = cnn_encoder(x, x_mask)
    y_att = att_encoder(x, x_mask)
    zz = dropout_md(y_att)
    print("The end.")
    pass
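Note: the x_mask comprehension builds a left-aligned 0/1 mask from random lengths; the same mask can be written with broadcasting. A small equivalent sketch:

import numpy as np

N_BATCH, N_SEQ = 8, 4
lengths = np.random.randint(N_SEQ // 2, N_SEQ, N_BATCH)                      # one length per sequence
x_mask = (np.arange(N_SEQ)[None, :] < lengths[:, None]).astype(np.float32)   # [N_BATCH, N_SEQ]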
Example #6
 def fb_on_batch(self,
                 annotated_insts: List[ParseInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     # encode
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
         annotated_insts, training)
     mask_expr = BK.input_real(mask_arr)
     # the parsing loss
     arc_score = self.scorer_helper.score_arc(enc_repr)
     lab_score = self.scorer_helper.score_label(enc_repr)
     full_score = arc_score + lab_score
     parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr)
     # other loss?
     jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
     reg_loss = self.reg_scores_loss(arc_score, lab_score)
     #
     info["loss_parse"] = BK.get_value(parsing_loss).item()
     final_loss = parsing_loss
     if jpos_loss is not None:
         info["loss_jpos"] = BK.get_value(jpos_loss).item()
         final_loss = parsing_loss + jpos_loss
     if reg_loss is not None:
         final_loss = final_loss + reg_loss
     info["fb"] = 1
     if training:
         BK.backward(final_loss, loss_factor)
     return info
Example #7
 def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
     mask_idxes, mask_valids = BK.mask2idx(
         BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     else:
         target_reprs = BK.gather_first_dims(repr_t, mask_idxes,
                                             1)  # [bsize, ?, *]
         target_hids = self.hid_layer(target_reprs)
         target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
         pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
         target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
          target_idx_t[mask_valids < 1.] = 0  # make sure invalid ones stay in range
         # get loss
         pred_losses = BK.loss_nll(target_scores,
                                   target_idx_t)  # [bsize, ?]
         pred_loss_sum = (pred_losses * mask_valids).sum()
         pred_loss_count = mask_valids.sum()
         # argmax
         _, argmax_idxes = target_scores.max(-1)
         pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
         pred_corr_count = pred_corrs.sum()
         return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
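Note: BK.mask2idx is not shown here; from its usage it appears to turn a [bsize, slen] 0/1 mask into padded position indices plus a validity mask (padding with 0 is an assumption, which is also why the code above zeroes out invalid targets). A loop-based plain-PyTorch sketch of that contract:

import torch

def mask2idx(mask: torch.Tensor, pad: int = 0):
    # mask: [bsize, slen] of 0./1. -> (idxes [bsize, max_k], valids [bsize, max_k])
    bsize, _ = mask.shape
    counts = mask.sum(-1).long()                      # selected positions per row
    max_k = int(counts.max().item())
    idxes = torch.full((bsize, max_k), pad, dtype=torch.long)
    valids = torch.zeros(bsize, max_k)
    for b in range(bsize):
        pos = mask[b].nonzero(as_tuple=True)[0]
        idxes[b, :len(pos)] = pos
        valids[b, :len(pos)] = 1.
    return idxes, valids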
Example #8
 def get_losses_from_attn_list(list_attn_info: List, ts_f, loss_f,
                               loss_prefix, loss_lambda):
     loss_num = None
     loss_counts: List[int] = []
     loss_sums: List[List] = []
     rets = []
     # -----
     for one_attn_info in list_attn_info:  # each update step
         one_ts: List = ts_f(
             one_attn_info)  # get tensor list from attn_info
         # get number of losses
         if loss_num is None:
             loss_num = len(one_ts)
             loss_counts = [0] * loss_num
             loss_sums = [[] for _ in range(loss_num)]
         else:
             assert len(one_ts) == loss_num, "mismatched ts length"
         # iter them
         for one_t_idx, one_t in enumerate(
                 one_ts):  # iter on the tensor list
             one_loss = loss_f(one_t)
             # need it to be in the corresponding shape
             loss_counts[one_t_idx] += np.prod(
                 BK.get_shape(one_loss)).item()
             loss_sums[one_t_idx].append(one_loss.sum())
     # for different steps
     for i, one_loss_count, one_loss_sums in zip(range(len(loss_counts)),
                                                 loss_counts, loss_sums):
         loss_leaf = LossHelper.compile_leaf_info(
             f"{loss_prefix}{i}",
             BK.stack(one_loss_sums, 0).sum(),
             BK.input_real(one_loss_count),
             loss_lambda=loss_lambda)
         rets.append(loss_leaf)
     return rets
Example #9
 def __call__(self,
              src,
              src_mask=None,
              qk_mask=None,
              rel_dist=None,
              forced_attns=None,
              collect_loss=False):
     if src_mask is not None:
         src_mask = BK.input_real(src_mask)
     if forced_attns is None:
         forced_attns = [None] * len(self.layers)
     # -----
     # forward
     temperature = self.temperature.value
     cur_hidden = src
     if len(self.layers) > 0:
         cache = self.layers[0].init_call(src)
         for one_lidx, one_layer in enumerate(self.layers):
             cur_hidden = one_layer.update_call(
                 cache,
                 src_mask=src_mask,
                 qk_mask=qk_mask,
                 attn_range=self.attn_ranges[one_lidx],
                 rel_dist=rel_dist,
                 temperature=temperature,
                 forced_attn=forced_attns[one_lidx])
     else:
         cache = VRecCache()  # empty one
     # -----
     # collect loss
     if collect_loss:
         loss_item = self._collect_losses(cache)
     else:
         loss_item = None
     return cur_hidden, cache, loss_item
Example #10
 def _score(self, insts: List[ParseInstance], training: bool,
            lambda_g1_arc: float, lambda_g1_lab: float):
     # encode
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
         insts, training)
     mask_expr = BK.input_real(mask_arr)
     # pruning and scores from g1
     valid_mask_d, valid_mask_s, go1_pack, arc_marginals = self._get_g1_pack(
         insts, mask_expr, lambda_g1_arc, lambda_g1_lab)
     # s-encode (using s-mask)
     final_valid_expr_s = self._make_final_valid(valid_mask_s, mask_expr)
     senc_repr = self.slayer(enc_repr, final_valid_expr_s, arc_marginals)
     # decode
     arc_score = self.scorer_helper.score_arc(senc_repr)
     lab_score = self.scorer_helper.score_label(senc_repr)
     full_score = arc_score + lab_score
     # add go1 scores and apply pruning (using d-mask)
      final_valid_expr_d = valid_mask_d.float()  # no need to mask out others here!
     mask_value = Constants.REAL_PRAC_MIN
     if go1_pack is not None:
         go1_arc_score, go1_label_score = go1_pack
         full_score += go1_arc_score.unsqueeze(-1) + go1_label_score
     full_score += (mask_value * (1. - final_valid_expr_d)).unsqueeze(-1)
     # [*, m, h, lab], (original-scores), [*, m], [*, m, h]
     return full_score, (
         arc_score, lab_score
     ), jpos_pack, mask_expr, final_valid_expr_d, final_valid_expr_s
Example #11
 def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf,
              vpack: VocabPackage):
     super().__init__(pc, None, None)
     self.conf = conf
     # vocab and padder
     self.word_vocab = vpack.get_voc("word")
     self.padder = DataPadder(
         2, pad_vals=self.word_vocab.pad,
         mask_range=2)  # todo(note): <pad>-id is very large
     # models
     self.hid_layer = self.add_sub_node(
         "hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
     self.pred_layer = self.add_sub_node(
         "pred",
         Affine(pc,
                conf.hid_dim,
                conf.max_pred_rank + 1,
                init_rop=NoDropRop()))
     if conf.init_pred_from_pretrain:
         npvec = vpack.get_emb("word")
         if npvec is None:
             zwarn(
                 "Pretrained vector not provided, skip init pred embeddings!!"
             )
         else:
             with BK.no_grad_env():
                 self.pred_layer.ws[0].copy_(
                     BK.input_real(npvec[:conf.max_pred_rank + 1].T))
             zlog(
                 f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1})."
             )
Example #12
 def _score(self, repr_t, attn_t, mask_t):
     conf = self.conf
     # -----
     repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
     repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
     scores0 = self.dps_node.paired_score(
         repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
     # mask at outside
     slen = BK.get_shape(mask_t, -1)
     score_mask = BK.constants(BK.get_shape(scores0)[:-1],
                               1.)  # [bs, len_q, len_k]
     score_mask *= (1. - BK.eye(slen))  # no diag
     score_mask *= mask_t.unsqueeze(-1)  # input mask at len_k
     score_mask *= mask_t.unsqueeze(-2)  # input mask at len_q
     NEG = Constants.REAL_PRAC_MIN
      scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
     # add fixed idx0 scores if set
     if conf.fix_s0:
         fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
          scores1 = (1. - fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
     # minus s0
     if conf.minus_s0:
         scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
     return scores1, score_mask
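Note: Constants.REAL_PRAC_MIN is used as a practical -inf: adding NEG * (1 - mask) pushes masked cells far below any real score so they vanish after softmax. A tiny plain-PyTorch illustration (-1e30 is a stand-in for the actual constant, which is assumed):

import torch

NEG = -1e30                                  # stand-in for Constants.REAL_PRAC_MIN
scores = torch.randn(1, 4)
mask = torch.tensor([[1., 1., 0., 1.]])      # 1 = keep, 0 = mask out
probs = (scores + NEG * (1. - mask)).softmax(-1)
# probs[0, 2] is (numerically) 0; the remaining cells renormalize among themselves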
Example #13
 def _prepare_score(self, insts, training):
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
         insts, training)
     mask_expr = BK.input_real(mask_arr)
     # am_expr, ah_expr, lm_expr, lh_expr = self.scorer.transform_space(enc_repr)
     scoring_expr_pack = self.scorer.transform_space(enc_repr)
     return scoring_expr_pack, mask_expr, jpos_pack
Example #14
def encode_bert(berter: Berter, sent: List[str], b_mask_mode, b_mask_repl):
    assert berter.bconf.bert_sent_extend == 0
    assert berter.bconf.bert_root_mode == -1
    # prepare inputs
    all_sentences = [berter.subword_tokenize(sent, True)]
    for i in range(len(sent)):
        all_sentences.append(
            berter.subword_tokenize(sent.copy(),
                                    True,
                                    mask_idx=i,
                                    mask_mode=b_mask_mode,
                                    mask_repl=b_mask_repl))
        # all_sentences.append(berter.subword_tokenize(sent.copy(), True, mask_idx=-1, mask_mode=b_mask_mode, mask_repl=b_mask_repl))
    # get outputs
    all_features = berter.extract_features(
        all_sentences)  # List[arr[1+slen, D]]
    # all_features = berter.extract_feature_simple_mode(all_sentences)  # List[arr[1+slen, D]]
    # =====
    # post-processing for pass mode
    if b_mask_mode == "pass":
        for i in range(1, len(all_features)):
            all_features[i] = np.insert(all_features[i], i, 0.,
                                        axis=0)  # [slen, D] -> [slen+1, D]
    return BK.input_real(np.stack(all_features,
                                  0))  # [1(whole)+slen, 1(R)+slen, D]
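Note: in "pass" mode the masked token appears to be dropped from the subword sequence, so the i-th feature matrix comes back one row short; np.insert puts a zero row back at position i so all matrices line up again. A small demonstration of that realignment (shapes assumed from the comments above):

import numpy as np

feats = np.zeros((5, 8))                     # [slen, D]: features with token i removed
realigned = np.insert(feats, 2, 0., axis=0)  # zero row inserted at position 2 -> [slen+1, D]
assert realigned.shape == (6, 8)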
Example #15
 def __call__(self,
              word_arr: np.ndarray = None,
              char_arr: np.ndarray = None,
              extra_arrs: Iterable[np.ndarray] = (),
              aux_arrs: Iterable[np.ndarray] = ()):
     exprs = []
     # word/char/extras/posi
     seq_shape = None
     if self.has_word:
         # todo(warn): singleton-UNK-dropout should be done outside before
         seq_shape = word_arr.shape
         word_expr = self.dropmd_word(self.word_embed(word_arr))
         exprs.append(word_expr)
     if self.has_char:
         seq_shape = char_arr.shape[:-1]
         char_embeds = self.char_embed(
             char_arr)  # [*, seq-len, word-len, D]
         char_cat_expr = self.dropmd_char(
             BK.concat([z(char_embeds) for z in self.char_cnns]))
         exprs.append(char_cat_expr)
     zcheck(
         len(extra_arrs) == len(self.extra_embeds),
         "Unmatched extra fields.")
     for one_extra_arr, one_extra_embed, one_extra_dropmd in zip(
             extra_arrs, self.extra_embeds, self.dropmd_extras):
         seq_shape = one_extra_arr.shape
         exprs.append(one_extra_dropmd(one_extra_embed(one_extra_arr)))
     if self.has_posi:
         seq_len = seq_shape[-1]
         posi_idxes = BK.arange_idx(seq_len)
         posi_input0 = self.posi_embed(posi_idxes)
         for _ in range(len(seq_shape) - 1):
             posi_input0 = BK.unsqueeze(posi_input0, 0)
         posi_input1 = BK.expand(posi_input0, tuple(seq_shape) + (-1, ))
         exprs.append(posi_input1)
     #
     assert len(aux_arrs) == len(self.drop_auxes)
     for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \
             zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes, self.aux_overall_gammas, self.aux_fold_lambdas):
         # fold and apply trainable lambdas
         input_aux_repr = BK.input_real(one_aux_arr)
         input_shape = BK.get_shape(input_aux_repr)
         # todo(note): assume the original concat is [fold/layer, D]
         reshaped_aux_repr = input_aux_repr.view(
             input_shape[:-1] +
             [one_fold, one_aux_dim])  # [*, slen, fold, D]
          lambdas_softmax = BK.softmax(one_lambdas, -1).unsqueeze(-1)  # [fold, 1]
          weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax).sum(-2) * one_gamma  # [*, slen, D]
         one_aux_expr = one_aux_drop(weighted_aux_repr)
         exprs.append(one_aux_expr)
     #
     concated_exprs = BK.concat(exprs, dim=-1)
     # optional proj
     if self.has_proj:
         final_expr = self.final_layer(concated_exprs)
     else:
         final_expr = concated_exprs
     return final_expr
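Note: the aux branch is a small scalar mix: the pre-extracted representation is reshaped to [*, slen, fold, D], a softmax over per-fold weights mixes the folds, and an overall gamma rescales the result. A self-contained plain-PyTorch sketch of that computation (parameter names mirror the loop variables above but are otherwise assumed):

import torch

bs, slen, fold, dim = 2, 6, 3, 16
aux = torch.randn(bs, slen, fold * dim)                   # flat input, concatenated over folds
fold_lambdas = torch.zeros(fold, requires_grad=True)      # trainable per-fold weights
gamma = torch.ones((), requires_grad=True)                # trainable overall scale

reshaped = aux.view(bs, slen, fold, dim)                  # [*, slen, fold, D]
weights = fold_lambdas.softmax(-1).unsqueeze(-1)          # [fold, 1]
mixed = (reshaped * weights).sum(-2) * gamma              # [*, slen, D]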
Example #16
 def fb_on_batch(self,
                 annotated_insts: List[ParseInstance],
                 training=True,
                 loss_factor=1.,
                 **kwargs):
     self.refresh_batch(training)
     # pruning and scores from g1
     valid_mask, go1_pack = self._get_g1_pack(annotated_insts,
                                              self.lambda_g1_arc_training,
                                              self.lambda_g1_lab_training)
     # encode
     input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
         annotated_insts, training)
     mask_expr = BK.input_real(mask_arr)
     # the parsing loss
     final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
     parsing_loss, parsing_scores, info = \
         self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value)
     info["loss_parse"] = BK.get_value(parsing_loss).item()
     final_loss = parsing_loss
     # other loss?
     jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
     if jpos_loss is not None:
         info["loss_jpos"] = BK.get_value(jpos_loss).item()
         final_loss = parsing_loss + jpos_loss
     if parsing_scores is not None:
         reg_loss = self.reg_scores_loss(*parsing_scores)
         if reg_loss is not None:
             final_loss = final_loss + reg_loss
     info["fb"] = 1
     if training:
         BK.backward(final_loss, loss_factor)
     return info
Example #17
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     iconf = self.conf.iconf
     pconf = iconf.pruning_conf
     with BK.no_grad_env():
         self.refresh_batch(False)
         if iconf.use_pruning:
             # todo(note): for the testing of pruning mode, use the scores instead
             if self.g1_use_aux_scores:
                 valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune(
                     insts, self.num_label, pconf)
             else:
                 valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch(
                     insts, pconf)
             valid_mask_f = valid_mask.float()  # [*, len, len]
             mask_value = Constants.REAL_PRAC_MIN
             full_score = arc_score.unsqueeze(-1) + label_score
             full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1)
             info_pruning = G1Parser.collect_pruning_info(
                 insts, valid_mask_f)
             jpos_pack = [None, None, None]
         else:
             input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
                 insts, False)
             mask_expr = BK.input_real(mask_arr)
             full_score = self.scorer_helper.score_full(enc_repr)
             info_pruning = None
         # =====
         self._decode(insts, full_score, mask_expr, "g1")
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         if info_pruning is not None:
             info.update(info_pruning)
         return info
Example #18
 def run(self, insts: List[DocInstance]):
     conf = self.conf
     # 1. build the query: DocEncoding
     cls_list = []  # [NumDoc, NumSent]
     for one_doc in insts:
         cls_list.append([
             z.extra_features["aux_repr"][0] for z in one_doc.sents
         ])  # [0] as [CLS]
     # padding
     pad_repr = np.zeros(self.input_dim)
     cls_repr_arr, cls_mask_arr = self._pad_3d_arr(
         cls_list, pad_repr)  # [NumDoc, MaxDLen, D]
     enc_doc_repr = self.enc_doc(BK.input_real(cls_repr_arr),
                                 cls_mask_arr)  # [NumDoc, MaxDLen, D']
     # 2. build the doc hints
     all_keyword_reprs, all_keysent_reprs = [], []  # [NumDoc, ?, D]
     for one_doc in insts:
         one_keyword_reprs, one_keysent_reprs = self._get_keyws(one_doc)
         all_keyword_reprs.append(one_keyword_reprs)
         all_keysent_reprs.append(one_keysent_reprs)
     # 2.5 attention
     final_inputs = [enc_doc_repr]
     if conf.use_keyword:
         keyword_repr_arr, keyword_mask_arr = self._pad_3d_arr(
             all_keyword_reprs, pad_repr)  # [NumDoc, Kw, D]
         keyword_repr_t, keyword_mask_t = BK.input_real(
             keyword_repr_arr), BK.input_real(keyword_mask_arr)
         att_kw_repr = self.kw_att(
             keyword_repr_t,
             keyword_repr_t,
             enc_doc_repr,
             mask_k=keyword_mask_t)  # [NumDoc, MaxDLen, D]
         final_inputs.append(att_kw_repr)
     if conf.use_keysent:
         keysent_repr_arr, keysent_mask_arr = self._pad_3d_arr(
             all_keysent_reprs, pad_repr)  # [NumDoc, Ks, D]
         keysent_repr_t, keysent_mask_t = BK.input_real(
             keysent_repr_arr), BK.input_real(keysent_mask_arr)
         att_ks_repr = self.ks_att(
             keysent_repr_t,
             keysent_repr_t,
             enc_doc_repr,
             mask_k=keysent_mask_t)  # [NumDoc, MaxDLen, D]
         final_inputs.append(att_ks_repr)
     # 3. combine all and return
     final_repr = self.final_layer(final_inputs)  # [NumDoc, MaxDLen, Dout]
     return final_repr
Example #19
 def batch_inputs_g1(self, insts: List[Sentence]):
     train_reverse_evetns = self.conf.train_reverse_evetns  # todo(note): this option is from derived class
      _tmp_f = (lambda x: list(reversed(x))) if train_reverse_evetns else (lambda x: x)
     key, items_getter = self.extract_type, self.items_getter
     # nil_idx = 0  # nil means eos
     # get gold/input data and batch
     all_widxes, all_lidxes, all_vmasks, all_items, all_valid = [], [], [], [], []
     for sent in insts:
         preps = sent.preps.get(key)
         # not cached, rebuild them
         if preps is None:
             items = items_getter(sent)
             # todo(note): directly add, assume they are already sorted in a good way (widx+lidx); 0(nil) as eos
             if items is None:
                 prep_valid = 0.
                 # prep_widxes, prep_lidxes, prep_vmasks, prep_items = [0], [0], [1.], [None]
                 prep_widxes, prep_lidxes, prep_vmasks, prep_items = [], [], [], []
             else:
                 prep_valid = 1.
                 prep_widxes = _tmp_f(
                     [z.mention.hard_span.head_wid for z in items]) + [0]
                 prep_lidxes = _tmp_f(
                     [self.hlidx2idx(z.type_idx) for z in items]) + [0]
                 prep_vmasks = [1.] * (len(items) + 1)
                 prep_items = _tmp_f(items.copy()) + [None]
             sent.preps[key] = (prep_widxes, prep_lidxes, prep_vmasks,
                                prep_items, prep_valid)
         else:
             prep_widxes, prep_lidxes, prep_vmasks, prep_items, prep_valid = preps
         # =====
         all_widxes.append(prep_widxes)
         all_lidxes.append(prep_lidxes)
         all_vmasks.append(prep_vmasks)
         all_items.append(prep_items)
         all_valid.append(prep_valid)
     # pad and batch
     mention_widxes = BK.input_idx(
         self.padder_idxes.pad(all_widxes)[0])  # [*, ?]
     mention_lidxes = BK.input_idx(
         self.padder_idxes.pad(all_lidxes)[0])  # [*, ?]
     mention_vmasks = BK.input_real(
         self.padder_mask.pad(all_vmasks)[0])  # [*, ?]
     mention_items_arr, _ = self.padder_items.pad(all_items)  # [*, ?]
     mention_valid = BK.input_real(all_valid)  # [*]
     return mention_widxes, mention_lidxes, mention_vmasks, mention_items_arr, mention_valid
Example #20
 def _init_fixed_mask(self, enc_mask_arr):
     tmp_device = BK.CPU_DEVICE
     # by token mask
     mask_ct = BK.input_real(enc_mask_arr, device=tmp_device)  # [*, len]
     full_mask_ct = mask_ct.unsqueeze(-1) * mask_ct.unsqueeze(
         -2)  # [*, len-mod, len-head]
     # no self loop
     full_mask_ct *= (1. - BK.eye(self.max_slen, device=tmp_device))
     # no root as mod; todo(warn): assume it is 3D
     full_mask_ct[:, 0, :] = 0.
     return full_mask_ct
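Note: the fixed mask is the outer product of the token mask with itself, with the diagonal removed (no self loops) and row 0 zeroed so the artificial root can only act as a head. A small plain-PyTorch sketch of the same construction:

import torch

mask = torch.tensor([[1., 1., 1., 0.]])              # [*, len], last token is padding
full = mask.unsqueeze(-1) * mask.unsqueeze(-2)       # [*, len-mod, len-head]
full = full * (1. - torch.eye(4))                    # no self loops
full[:, 0, :] = 0.                                   # root position cannot be a modifier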
Example #21
 def fb_on_batch(self, annotated_insts: List[DocInstance], training=True, loss_factor=1., **kwargs):
     self.refresh_batch(training)
     self.evt_extractor.set_constrain_evt_types(self.train_constrain_evt_types)  # also ignore irrelevant types for training
     ndoc, nsent = len(annotated_insts), 0
     margin = self.margin.value
     lambda_ef, lambda_evt, lambda_arg = self.lambda_ef.value, self.lambda_evt.value, self.lambda_arg.value
     lookup_ef, lookup_evt = self.conf.tconf.lookup_ef, self.conf.tconf.lookup_evt
     #
     has_loss_ef = has_loss_evt = has_loss_arg = lambda_arg > 0.
     has_loss_ef = has_loss_ef or (lambda_ef > 0.)
     has_loss_evt = has_loss_evt or (lambda_evt > 0.)
     # splitting into buckets
     all_packs = self.bter.run(annotated_insts, training=training)
     all_ef_losses = []
     all_evt_losses = []
     all_arg_losses = []
     for one_pack in all_packs:
         # =====
         # predict
         sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
         nsent += len(sent_insts)
         mask_expr = BK.input_real(mask_arr)
         # entity and filler
         if lookup_ef:
             ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                 self._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor)
         elif has_loss_ef:
             ef_losses, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                 self._fb_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, margin)
             all_ef_losses.append(ef_losses)
         # event
         if lookup_evt:
             evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                 self._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor)
         elif has_loss_evt:
             evt_losses, evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                 self._fb_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, margin)
             all_evt_losses.append(evt_losses)
         # arg
         if has_loss_arg:
             # todo(note): for training, we only consider inner-sentence pairs,
             #  since most of the training data is like this
             arg_losses = self._fb_args(ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                                        evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin)
             all_arg_losses.append(arg_losses)
     # =====
     # final loss sum and backward
     info = {"doc": ndoc, "sent": nsent, "fb": 1}
     if len(all_packs) == 0:
         return info
     self.collect_loss_and_backward(["ef", "evt", "arg"], [all_ef_losses, all_evt_losses, all_arg_losses],
                                    [lambda_ef, lambda_evt, lambda_arg], info, training, loss_factor)
     return info
Example #22
 def prune_on_batch(self, insts: List[ParseInstance], pconf: PruneG1Conf):
     with BK.no_grad_env():
         self.refresh_batch(False)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         arc_score = self.scorer_helper.score_arc(enc_repr)
         label_score = self.scorer_helper.score_label(enc_repr)
         final_valid_mask, arc_marginals = G1Parser.prune_with_scores(
             arc_score, label_score, mask_expr, pconf)
          return final_valid_mask, arc_score.squeeze(-1), label_score, mask_expr, arc_marginals
Example #23
 def collect_aux_scores(insts: List[ParseInstance], output_num_label):
     score_tuples = [z.extra_features["aux_score"] for z in insts]
     num_label = score_tuples[0][1].shape[-1]
     max_len = max(len(z) + 1 for z in insts)
     mask_value = Constants.REAL_PRAC_MIN
     bsize = len(insts)
     arc_score_arr = np.full([bsize, max_len, max_len],
                             mask_value,
                             dtype=np.float32)
     lab_score_arr = np.full([bsize, max_len, max_len, output_num_label],
                             mask_value,
                             dtype=np.float32)
     mask_arr = np.full([bsize, max_len], 0., dtype=np.float32)
     for bidx, one_tuple in enumerate(score_tuples):
         one_score_arc, one_score_lab = one_tuple
         one_len = one_score_arc.shape[1]
         arc_score_arr[bidx, :one_len, :one_len] = one_score_arc
         lab_score_arr[bidx, :one_len, :one_len,
                       -num_label:] = one_score_lab
         mask_arr[bidx, :one_len] = 1.
     return BK.input_real(arc_score_arr).unsqueeze(-1), BK.input_real(
         lab_score_arr), BK.input_real(mask_arr)
Example #24
 def pad_chs(self, idxes_list: List[List], labels_list: List[List]):
     start_posi = self.chs_start_posi
     if start_posi < 0:  # truncate
         idxes_list = [x[start_posi:] for x in idxes_list]
     # overall valid mask
     chs_valid = [(0. if len(z) == 0 else 1.) for z in idxes_list]
      # proceed only if every instance in the batch has at least one valid child
     if all(x > 0 for x in chs_valid):
         padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(
             idxes_list)  # [*, max-ch], [*, max-ch]
         if self.use_label_feat:
             if start_posi < 0:  # truncate
                 labels_list = [x[start_posi:] for x in labels_list]
             padded_chs_labels, _ = self.ch_label_padder.pad(
                 labels_list)  # [*, max-ch]
             chs_label_t = BK.input_idx(padded_chs_labels)
         else:
             chs_label_t = None
         chs_idxes_t, chs_mask_t, chs_valid_mask_t = \
             BK.input_idx(padded_chs_idxes), BK.input_real(padded_chs_mask), BK.input_real(chs_valid)
         return chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t
     else:
         return None, None, None, None
Example #25
 def refresh(self, rop=None):
     super().refresh(rop)
     # no need to fix0 for None since already done in the Embedding
     # refresh layered embeddings (in training, we should not be in no-grad mode)
     # todo(note): here, there can be dropouts
     layered_prei_arrs = self.hl_vocab.layered_prei
     layered_pool_links_padded_arrs = self.hl_vocab.layered_pool_links_padded
     layered_pool_links_mask_arrs = self.hl_vocab.layered_pool_links_mask
     layered_isnil = self.hl_vocab.layered_pool_isnil
     for i in range(self.max_layer):
         # [N, ?, D] -> [N, D] -> [D, N]
         self.layered_embeds_pred[i] = (
             BK.input_real(layered_pool_links_mask_arrs[i]).unsqueeze(-1) *
             self.pool_pred(
                 layered_pool_links_padded_arrs[i])).sum(-2).transpose(
                     0, 1).contiguous()
         # [N, ?, D] -> [N, D]
         self.layered_embeds_lookup[i] = (
             BK.input_real(layered_pool_links_mask_arrs[i]).unsqueeze(-1) *
             self.pool_lookup(layered_pool_links_padded_arrs[i])).sum(-2)
         # [?] of idxes/masks
         self.layered_prei[i] = BK.input_idx(layered_prei_arrs[i])
         self.layered_isnil[i] = BK.input_real(
             layered_isnil[i])  # is nil mask
Example #26
def nmarginal_proj(scores_expr, mask_expr, lengths_arr, labeled=True):
    assert labeled
    with BK.no_grad_env():
        # first make it unlabeled by sum-exp
        scores_unlabeled = BK.logsumexp(scores_expr, dim=-1)  # [BS, m, h]
        # marginal for unlabeled
        scores_unlabeled_arr = BK.get_value(scores_unlabeled)
        marginals_unlabeled_arr = marginal_proj(scores_unlabeled_arr,
                                                lengths_arr, False)
        # back to labeled values
        marginals_unlabeled_expr = BK.input_real(marginals_unlabeled_arr)
        marginals_labeled_expr = marginals_unlabeled_expr.unsqueeze(
            -1) * BK.exp(scores_expr - scores_unlabeled.unsqueeze(-1))
        # [BS, m, h, L]
        return _ensure_margins_norm(marginals_labeled_expr)
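Note: the trick here is that summing the label scores out with logsumexp gives unlabeled arc scores; after computing the unlabeled (projective) arc marginals on CPU, each one is split back over labels by the conditional p(label | arc) = exp(score - logsumexp). A short plain-PyTorch sketch of that last step (the arc marginals are random stand-ins for marginal_proj):

import torch

scores = torch.randn(2, 5, 5, 8)                       # [BS, m, h, L] labeled scores
unlabeled = scores.logsumexp(-1)                       # [BS, m, h]
arc_marginals = torch.rand(2, 5, 5)                    # stand-in for the projective marginals
labeled_marginals = arc_marginals.unsqueeze(-1) * (scores - unlabeled.unsqueeze(-1)).exp()
# labeled_marginals.sum(-1) recovers arc_marginals (up to floating point)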
Example #27
 def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
     # iconf = self.conf.iconf
     with BK.no_grad_env():
         self.refresh_batch(False)
         # pruning and scores from g1
         valid_mask, go1_pack = self._get_g1_pack(
             insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing)
         # encode
         input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(
             insts, False)
         mask_expr = BK.input_real(mask_arr)
         # decode
         final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
         ret_heads, ret_labels, _, _ = self.dl.decode(
             insts, enc_repr, final_valid_expr, go1_pack, False, 0.)
         # collect the results together
         all_heads = Helper.join_list(ret_heads)
         if ret_labels is None:
             # todo(note): simply get labels from the go1-label classifier; must provide g1parser
             if go1_pack is None:
                 _, go1_pack = self._get_g1_pack(insts, 1., 1.)
             _, go1_label_max_idxes = go1_pack[1].max(
                 -1)  # [bs, slen, slen]
             pred_heads_arr, _ = self.predict_padder.pad(
                 all_heads)  # [bs, slen]
             pred_heads_expr = BK.input_idx(pred_heads_arr)
             pred_labels_expr = BK.gather_one_lastdim(
                 go1_label_max_idxes, pred_heads_expr).squeeze(-1)
             all_labels = BK.get_value(pred_labels_expr)  # [bs, slen]
         else:
             all_labels = np.concatenate(ret_labels, 0)
         # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
         for one_idx, one_inst in enumerate(insts):
             cur_length = len(one_inst) + 1
             one_inst.pred_heads.set_vals(
                 all_heads[one_idx]
                 [:cur_length])  # directly int-val for heads
             one_inst.pred_labels.build_vals(
                 all_labels[one_idx][:cur_length], self.label_vocab)
             # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length])
         # =====
         # put jpos result (possibly)
         self.jpos_decode(insts, jpos_pack)
         # -----
         info = {"sent": len(insts), "tok": sum(map(len, insts))}
         return info
Example #28
def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr):
    assert labeled
    with BK.no_grad_env():
        # argmax-label: [BS, m, h]
        scores_unlabeled_max, labels_argmax = scores_expr.max(-1)
        #
        scores_unlabeled_max_arr = BK.get_value(scores_unlabeled_max)
        mst_heads_arr, _, mst_scores_arr = CPU_f(scores_unlabeled_max_arr,
                                                 lengths_arr,
                                                 labeled=False)
        # [BS, m]
        mst_heads_expr = BK.input_idx(mst_heads_arr)
        mst_labels_expr = BK.gather_one_lastdim(labels_argmax,
                                                mst_heads_expr).squeeze(-1)
        # prepare for the outputs
        if ret_arr:
            return mst_heads_arr, BK.get_value(mst_labels_expr), mst_scores_arr
        else:
            return mst_heads_expr, mst_labels_expr, BK.input_real(
                mst_scores_arr)
Example #29
 def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
              conf: EmbedderNodeConf, vpack: VocabPackage):
     super().__init__(pc, comp_name, ec_conf, conf, vpack)
     # -----
     # get embeddings
     npvec = None
     if self.ec_conf.comp_init_from_pretrain:
         npvec = vpack.get_emb(comp_name)
         zlog(f"Try to init InputEmbedNode {comp_name} with npvec.shape={npvec.shape if (npvec is not None) else None}")
         if npvec is None:
             zwarn("Warn: cannot get pre-trained embeddings to init!!")
     # get rare unk range
     # - get freq vals, make sure special ones will not be pruned; todo(note): directly use that field
     voc_rare_mask = [float(z is not None and z<=ec_conf.comp_rare_thr) for z in self.voc.final_vals]
     self.rare_mask = BK.input_real(voc_rare_mask)
     self.use_rare_unk = (ec_conf.comp_rare_unk>0. and ec_conf.comp_rare_thr>0)
     # --
     # dropout outside explicitly
     self.E = self.add_sub_node(f"E{self.comp_name}", Embedding(
         pc, len(self.voc), self.comp_dim, fix_row0=conf.embed_fix_row0, npvec=npvec, name=comp_name,
         init_rop=NoDropRop(), init_scale=self.comp_init_scale))
     self.create_dropout_node()
Example #30
 def __call__(self, input_map: Dict):
     exprs = []
     # get masks: this mask is for validing of inst batching
     final_masks = BK.input_real(input_map["mask"])  # [*, slen]
     if self.add_root_token:  # append 1
         slice_t = BK.constants(BK.get_shape(final_masks)[:-1]+[1], 1.)
         final_masks = BK.concat([slice_t, final_masks], -1)  # [*, 1+slen]
     # -----
     # for each component
     for idx, name in enumerate(self.comp_names):
         cur_node = self.nodes[idx]
         cur_input = input_map[name]
         cur_expr = cur_node(cur_input, self.add_root_token)
         exprs.append(cur_expr)
     # -----
     concated_exprs = BK.concat(exprs, dim=-1)
     # optional proj
     if self.has_proj:
         final_expr = self.final_layer(concated_exprs)
     else:
         final_expr = concated_exprs
     return final_expr, final_masks
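Note: when add_root_token is set, a column of ones is prepended to the mask so downstream modules see an always-valid root slot in front of every sentence. A minimal plain-PyTorch equivalent of those two lines:

import torch

final_masks = torch.tensor([[1., 1., 0.], [1., 1., 1.]])     # [*, slen]
root_slice = torch.ones(final_masks.shape[:-1] + (1,))       # [*, 1]
final_masks = torch.cat([root_slice, final_masks], -1)       # [*, 1+slen]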