def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # encode
    input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
    mask_expr = BK.input_real(mask_arr)
    # the parsing loss
    arc_score = self.scorer_helper.score_arc(enc_repr)
    lab_score = self.scorer_helper.score_label(enc_repr)
    full_score = arc_score + lab_score
    parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr)
    # other loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    reg_loss = self.reg_scores_loss(arc_score, lab_score)
    # info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if reg_loss is not None:
        final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info
def get_losses_from_attn_list(list_attn_info: List, ts_f, loss_f, loss_prefix, loss_lambda):
    loss_num = None
    loss_counts: List[int] = []
    loss_sums: List[List] = []
    rets = []
    # -----
    for one_attn_info in list_attn_info:  # each update step
        one_ts: List = ts_f(one_attn_info)  # get tensor list from attn_info
        # get number of losses
        if loss_num is None:
            loss_num = len(one_ts)
            loss_counts = [0] * loss_num
            loss_sums = [[] for _ in range(loss_num)]
        else:
            assert len(one_ts) == loss_num, "mismatched ts length"
        # iter them
        for one_t_idx, one_t in enumerate(one_ts):  # iter on the tensor list
            one_loss = loss_f(one_t)  # need it to be in the corresponding shape
            loss_counts[one_t_idx] += np.prod(BK.get_shape(one_loss)).item()
            loss_sums[one_t_idx].append(one_loss.sum())  # for different steps
    for i, one_loss_count, one_loss_sums in zip(range(len(loss_counts)), loss_counts, loss_sums):
        loss_leaf = LossHelper.compile_leaf_info(f"{loss_prefix}{i}", BK.stack(one_loss_sums, 0).sum(),
                                                 BK.input_real(one_loss_count), loss_lambda=loss_lambda)
        rets.append(loss_leaf)
    return rets
def inference_on_batch(self, insts: List[ParseInstance], **kwargs):
    iconf = self.conf.iconf
    pconf = iconf.pruning_conf
    with BK.no_grad_env():
        self.refresh_batch(False)
        if iconf.use_pruning:
            # todo(note): for the testing of pruning mode, use the scores instead
            if self.g1_use_aux_scores:
                valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune(insts, self.num_label, pconf)
            else:
                valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch(insts, pconf)
            valid_mask_f = valid_mask.float()  # [*, len, len]
            mask_value = Constants.REAL_PRAC_MIN
            full_score = arc_score.unsqueeze(-1) + label_score
            full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1)
            info_pruning = G1Parser.collect_pruning_info(insts, valid_mask_f)
            jpos_pack = [None, None, None]
        else:
            input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False)
            mask_expr = BK.input_real(mask_arr)
            full_score = self.scorer_helper.score_full(enc_repr)
            info_pruning = None
        # =====
        self._decode(insts, full_score, mask_expr, "g1")
        # put jpos result (possibly)
        self.jpos_decode(insts, jpos_pack)
        # -----
        info = {"sent": len(insts), "tok": sum(map(len, insts))}
        if info_pruning is not None:
            info.update(info_pruning)
        return info
def loss(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # use gold targets: only use positive samples!!
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
    realis_flist = [(-1 if (z is None or z.realis_idx is None) else z.realis_idx) for z in items_arr.flatten()]
    realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
    realis_mask = (realis_t >= 0).float()
    realis_t.clamp_(min=0)  # make sure all idxes are legal
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build losses
    loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens, realis_t, realis_mask, conf.lambda_realis)
    loss_item_type = self._get_one_loss(self.type_predictor, hiddens, labels_t, masks_t, conf.lambda_type)
    return [loss_item_realis, loss_item_type]
def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # todo(note): here always using training lambdas
    full_score, original_scores, jpos_pack, mask_expr, valid_mask_d, _ = \
        self._score(annotated_insts, False, self.lambda_g1_arc_training, self.lambda_g1_lab_training)
    parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr, valid_mask_d)
    # other loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    reg_loss = self.reg_scores_loss(*original_scores)
    # info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if reg_loss is not None:
        final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info
def main(args):
    conf, model, vpack, test_iter = prepare_test(args)
    dconf = conf.dconf
    # todo(note): here is the main change
    # make sure the model is an order-1 graph model, otherwise it cannot run through
    all_results = []
    all_insts = []
    with utils.Timer(tag="Run-score", info="", print_date=True):
        for cur_insts in test_iter:
            all_insts.extend(cur_insts)
            batched_arc_scores, batched_label_scores = model.score_on_batch(cur_insts)
            batched_arc_scores, batched_label_scores = BK.get_value(batched_arc_scores), BK.get_value(batched_label_scores)
            for cur_idx in range(len(cur_insts)):
                cur_len = len(cur_insts[cur_idx]) + 1  # discarding paddings
                cur_res = (batched_arc_scores[cur_idx, :cur_len, :cur_len],
                           batched_label_scores[cur_idx, :cur_len, :cur_len])
                all_results.append(cur_res)
    # reorder to the original order
    orig_indexes = [z.inst_idx for z in all_insts]
    orig_results = [None] * len(orig_indexes)
    for new_idx, orig_idx in enumerate(orig_indexes):
        assert orig_results[orig_idx] is None
        orig_results[orig_idx] = all_results[new_idx]
    # saving
    with utils.Timer(tag="Run-write", info=f"Writing to {dconf.output_file}", print_date=True):
        import pickle
        with utils.zopen(dconf.output_file, "wb") as fd:
            for one in orig_results:
                pickle.dump(one, fd)
    utils.printing("The end.")
def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    # todo(note): use "x.entity_fillers" for getting gold args
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.entity_fillers, True, True, conf.train_neg_rate, conf.train_neg_rate_outside, True)
    labels_t.clamp_(max=1)  # either 0 or 1
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build loss
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_log_probs * masks_t
    # loss_sum, loss_count, gold_count
    return [[masked_losses.sum(), masks_t.sum(), (labels_t > 0).float().sum()]]
def _fb_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
             evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin):
    # get the gold idxes
    arg_linker = self.arg_linker
    bsize, len_ef = ef_items.shape
    bsize2, len_evt = evt_items.shape
    assert bsize == bsize2
    gold_idxes = np.zeros([bsize, len_ef, len_evt], dtype=np.long)
    for one_gold_idxes, one_ef_items, one_evt_items in zip(gold_idxes, ef_items, evt_items):
        # todo(note): check each pair
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            role_map = {id(z.evt): z.role_idx for z in one_ef.links}  # todo(note): since we get the original linked ones
            for evt_idx, one_evt in enumerate(one_evt_items):
                pairwise_role_hlidx = role_map.get(id(one_evt))
                if pairwise_role_hlidx is not None:
                    pairwise_role_idx = arg_linker.hlidx2idx(pairwise_role_hlidx)
                    assert pairwise_role_idx > 0
                    one_gold_idxes[ef_idx, evt_idx] = pairwise_role_idx
    # get loss
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    if np.prod(gold_idxes.shape) == 0:
        # no instances!
        return [[BK.zeros([]), BK.zeros([])]]
    else:
        gold_idxes_t = BK.input_idx(gold_idxes)
        return arg_linker.loss(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                               ef_valid_mask, evt_valid_mask, gold_idxes_t, margin)
def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t):
    ret_t = cur_t  # [*, D]
    # padding 0 if not using labels
    dim_label = self.dim_label
    # child features
    if self.use_chs and chs_t is not None:
        if self.use_label_feat:
            chs_label_rt = self.label_embeddings(chs_label_t)  # [*, max-chs, dlab]
        else:
            labels_shape = BK.get_shape(chs_t)
            labels_shape[-1] = dim_label
            chs_label_rt = BK.zeros(labels_shape)
        chs_input_t = BK.concat([chs_t, chs_label_rt], -1)
        chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t, chs_valid_mask_t)
        chs_feat = self.chs_ff(chs_feat0)
        ret_t += chs_feat
    # parent features
    if self.use_par and par_t is not None:
        if self.use_label_feat:
            cur_label_t = self.label_embeddings(label_t)  # [*, dlab]
        else:
            labels_shape = BK.get_shape(par_t)
            labels_shape[-1] = dim_label
            cur_label_t = BK.zeros(labels_shape)
        par_feat = self.par_ff([par_t, cur_label_t])
        if par_mask_t is not None:
            par_feat *= par_mask_t.unsqueeze(-1)
        ret_t += par_feat
    return ret_t
def __call__(self, scores, temperature=1., dim=-1):
    is_training = self.rop.training
    # only use stochastic at training
    if is_training:
        if self.use_gumbel:
            gumbel_eps = self.gumbel_eps
            G = (BK.rand(BK.get_shape(scores)) + gumbel_eps).clamp(max=1.)  # [0, 1)
            scores = scores - (gumbel_eps - G.log()).log()
    # normalize
    probs = BK.softmax(scores / temperature, dim=dim)  # [*, S]
    # prune and re-normalize?
    if self.prune_val > 0.:
        probs = probs * (probs > self.prune_val).float()
        # todo(note): currently no re-normalize
        # probs = probs / probs.sum(dim=dim, keepdim=True)  # [*, S]
    # argmax and ste
    if self.use_argmax:  # use the hard argmax
        max_probs, _ = probs.max(dim, keepdim=True)  # [*, 1]
        # todo(+N): currently we do not re-normalize here, should it be done here?
        st_probs = (probs >= max_probs).float() * probs  # [*, S]
        if is_training:  # (hard-soft).detach() + soft
            st_probs = (st_probs - probs).detach() + probs  # [*, S]
        return st_probs
    else:
        return probs
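# A minimal standalone sketch of the straight-through argmax used above, written in plain
# PyTorch for clarity (assumes torch is available; the helper name is hypothetical and not
# part of this codebase). The forward pass uses the hard, pruned probabilities, while
# gradients flow through the soft softmax via the (hard - soft).detach() + soft trick.
import torch

def straight_through_argmax(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    probs = torch.softmax(scores, dim=dim)  # soft distribution
    hard = (probs >= probs.max(dim=dim, keepdim=True).values).float() * probs  # keep only the max entries
    return (hard - probs).detach() + probs  # forward: hard; backward: gradient of soft probs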
def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
    conf = self.conf
    # -----
    # special mode
    if conf.aug_word2 and conf.aug_word2_aug_encoder:
        _rop = RefreshOptions(training=False)  # special feature-mode!!
        self.embedder.refresh(_rop)
        self.encoder.refresh(_rop)
    # -----
    emb_t, mask_t = self.embedder(cur_input_map)
    rel_dist = cur_input_map.get("rel_dist", None)
    if rel_dist is not None:
        rel_dist = BK.input_idx(rel_dist)
    if conf.enc_choice == "vrec":
        enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
    elif conf.enc_choice == "original":
        # todo(note): change back to arr for backward compatibility
        assert rel_dist is None, "Original encoder does not support rel_dist"
        enc_t = self.encoder(emb_t, BK.get_value(mask_t))
        cache, enc_loss = None, None
    else:
        raise NotImplementedError()
    # another encoder based on attn
    final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
    if conf.aug_word2:
        emb2_t = self.aug_word2(insts)
        if conf.aug_word2_aug_encoder:
            # simply add them all together, detach orig-enc as features
            stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
            features = self.aug_mixturer(stack_hidden_t)
            aug_input = (emb2_t + conf.aug_detach_ratio * self.aug_detach_drop(features))
            final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t, rel_dist=rel_dist,
                                                            collect_loss=collect_loss)
        else:
            final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
    return emb_t, mask_t, final_enc_t, cache, enc_loss
def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs):
    self.refresh_batch(training)
    # pruning and scores from g1
    valid_mask, go1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training, self.lambda_g1_lab_training)
    # encode
    input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training)
    mask_expr = BK.input_real(mask_arr)
    # the parsing loss
    final_valid_expr = self._make_final_valid(valid_mask, mask_expr)
    parsing_loss, parsing_scores, info = \
        self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value)
    info["loss_parse"] = BK.get_value(parsing_loss).item()
    final_loss = parsing_loss
    # other loss?
    jpos_loss = self.jpos_loss(jpos_pack, mask_expr)
    if jpos_loss is not None:
        info["loss_jpos"] = BK.get_value(jpos_loss).item()
        final_loss = parsing_loss + jpos_loss
    if parsing_scores is not None:
        reg_loss = self.reg_scores_loss(*parsing_scores)
        if reg_loss is not None:
            final_loss = final_loss + reg_loss
    info["fb"] = 1
    if training:
        BK.backward(final_loss, loss_factor)
    return info
def _get_basic_score(self, mb_enc_expr, batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes):
    allp_size = BK.get_shape(batch_idxes, 0)
    all_arc_scores, all_lab_scores = [], []
    cur_pidx = 0
    while cur_pidx < allp_size:
        next_pidx = min(allp_size, cur_pidx + self.mb_dec_sb)
        # first calculate srepr
        s_enc = self.slayer
        cur_batch_idxes = batch_idxes[cur_pidx:next_pidx]
        h_expr = mb_enc_expr[cur_batch_idxes, h_idxes[cur_pidx:next_pidx]]
        m_expr = mb_enc_expr[cur_batch_idxes, m_idxes[cur_pidx:next_pidx]]
        s_expr = mb_enc_expr[cur_batch_idxes, sib_idxes[cur_pidx:next_pidx]].unsqueeze(-2) \
            if (sib_idxes is not None) else None  # [*, 1, D]
        g_expr = mb_enc_expr[cur_batch_idxes, gp_idxes[cur_pidx:next_pidx]] if (gp_idxes is not None) else None
        head_srepr = s_enc.calculate_repr(h_expr, g_expr, None, None, s_expr, None, None, None)
        mod_srepr = s_enc.forward_repr(m_expr)
        # then get the scores
        arc_score = self.scorer.transform_and_arc_score_plain(mod_srepr, head_srepr).squeeze(-1)
        all_arc_scores.append(arc_score)
        if self.system_labeled:
            lab_score = self.scorer.transform_and_label_score_plain(mod_srepr, head_srepr)
            all_lab_scores.append(lab_score)
        cur_pidx = next_pidx
    final_arc_score = BK.concat(all_arc_scores, 0)
    final_lab_score = BK.concat(all_lab_scores, 0) if self.system_labeled else None
    return final_arc_score, final_lab_score
def _select_topk(self, masked_scores, pad_mask, ratio_mask, topk_ratio, thresh_k):
    slen = BK.get_shape(masked_scores, -1)
    sel_mask = BK.copy(pad_mask)
    # first apply the absolute thresh
    if thresh_k is not None:
        sel_mask *= (masked_scores > thresh_k).float()
    # then ratio-ed topk
    if topk_ratio > 0.:
        # prepare number
        cur_topk_num = ratio_mask.sum(-1)  # [*]
        cur_topk_num = (cur_topk_num * topk_ratio).long()  # [*]
        cur_topk_num.clamp_(min=1, max=slen)  # at least one, at most all
        # topk
        actual_max_k = max(cur_topk_num.max().item(), 1)
        topk_score, _ = BK.topk(masked_scores, actual_max_k, dim=-1, sorted=True)  # [*, k]
        thresh_score = topk_score.gather(-1, cur_topk_num.clamp(min=1).unsqueeze(-1) - 1)  # [*, 1]
        # get mask and apply
        sel_mask *= (masked_scores >= thresh_score).float()
    return sel_mask
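# A minimal standalone sketch of the ratio-ed per-row top-k selection above, in plain PyTorch
# (assumes torch; the helper name is hypothetical). Each row keeps its floor(ratio * n_valid)
# highest scores by looking up the k-th largest score per row and thresholding against it.
import torch

def select_topk_ratio(scores: torch.Tensor, valid_mask: torch.Tensor, ratio: float) -> torch.Tensor:
    k_per_row = (valid_mask.sum(-1) * ratio).long().clamp(min=1, max=scores.shape[-1])  # [*]
    max_k = int(k_per_row.max().item())
    topk_scores, _ = scores.topk(max_k, dim=-1, sorted=True)  # [*, max_k]
    row_thresh = topk_scores.gather(-1, (k_per_row - 1).unsqueeze(-1))  # [*, 1], the k-th largest per row
    return valid_mask * (scores >= row_thresh).float()  # [*, n]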
def get_selected_label_scores(self, idxes_m_t, idxes_h_t, bsize_range_t, oracle_mask_t, oracle_label_t,
                              arc_margin: float, label_margin: float):
    # todo(note): in this mode, no repeated arc_margin
    dim1_range_t = bsize_range_t
    dim2_range_t = dim1_range_t.unsqueeze(-1)
    if self.system_labeled:
        selected_m_cache = [z[dim2_range_t, idxes_m_t] for z in self.mod_label_cache]
        selected_h_repr = self.head_label_cache[dim2_range_t, idxes_h_t]
        ret = self.scorer.score_label(selected_m_cache, selected_h_repr)  # [*, k, labels]
        if label_margin > 0.:
            oracle_label_idxes = oracle_label_t[dim2_range_t, idxes_m_t, idxes_h_t].unsqueeze(-1)  # [*, k, 1] of int
            ret.scatter_add_(-1, oracle_label_idxes, BK.constants(oracle_label_idxes.shape, -label_margin))
    else:
        # todo(note): otherwise, simply put zeros (with idx=0 as the slightly best to be consistent)
        ret = BK.zeros(BK.get_shape(idxes_m_t) + [self.num_label])
        ret[:, :, 0] += 0.01
    if self.g1_lab_scores is not None:
        ret += self.g1_lab_scores[dim2_range_t, idxes_m_t, idxes_h_t]
    return ret
def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float],
          label_weight_list: List[float], bidxes_list: List[int]):
    # 1. collect (batched) features; todo(note): use prev state for scoring
    hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
    # 2. get new sreprs
    scorer = self.scorer
    s_enc = self.slayer
    bsize_range_t = BK.input_idx(bidxes_list)
    node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
    node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
    # label loss
    if self.system_labeled:
        node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
        _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
        label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
        label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
        final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
    else:
        label_scores = final_label_loss_sum = BK.zeros([])
    # arc loss
    node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
    _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
    arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
    final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
    # score reg
    return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
def predict(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs):
    conf = self.conf
    # scoring
    arc_score, lab_score = self._score(enc_expr, mask_expr)  # [bs, m, h, *]
    full_score = BK.log_softmax(arc_score, -2) + BK.log_softmax(lab_score, -1)  # [bs, m, h, Lab]
    # decode
    mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
    mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
    mst_heads_arr, mst_labels_arr, mst_scores_arr = \
        nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True)
    # ===== assign, todo(warn): here, the labels are directly original idx, no need to change
    misc_prefix = "g"
    for one_idx, one_inst in enumerate(insts):
        cur_length = mst_lengths[one_idx]
        one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length])  # directly int-val for heads
        one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab)
        one_scores = mst_scores_arr[one_idx][:cur_length]
        one_inst.pred_par_scores.set_vals(one_scores)
        # extra output
        one_inst.extra_pred_misc[misc_prefix + "_score"] = one_scores.tolist()
def _inference_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                    evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt):
    arg_linker = self.arg_linker
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    role_logprobs, role_predictions = arg_linker.predict(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                                                         ef_valid_mask, evt_valid_mask)
    # add them in place
    roles_arr = BK.get_value(role_predictions)  # [*, len-ef, len-evt]
    logprobs_arr = BK.get_value(role_logprobs)
    for bidx, one_roles_arr in enumerate(roles_arr):
        one_ef_items, one_evt_items = ef_items[bidx], evt_items[bidx]
        # =====
        # todo(note): delete original links!
        for z in one_ef_items:
            if z is not None:
                z.links.clear()
        for z in one_evt_items:
            if z is not None:
                z.links.clear()
        # =====
        one_logprobs = logprobs_arr[bidx]
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            for evt_idx, one_evt in enumerate(one_evt_items):
                if one_evt is None:
                    continue
                one_role_idx = int(one_roles_arr[ef_idx, evt_idx])
                if one_role_idx > 0:
                    # link
                    this_hlidx = arg_linker.idx2hlidx(one_role_idx)
                    one_evt.add_arg(one_ef, role=str(this_hlidx), role_idx=this_hlidx,
                                    score=float(one_logprobs[ef_idx, evt_idx]))
def main():
    pc = BK.ParamCollection()
    N_BATCH, N_SEQ = 8, 4
    N_HIDDEN, N_LAYER = 5, 3
    N_INPUT = N_HIDDEN
    N_FF = 10
    # encoders
    rnn_encoder = layers.RnnLayerBatchFirstWrapper(pc, layers.RnnLayer(pc, N_INPUT, N_HIDDEN, N_LAYER, bidirection=True))
    cnn_encoder = layers.Sequential(pc, [layers.CnnLayer(pc, N_INPUT, N_HIDDEN, 3, act="relu") for _ in range(N_LAYER)])
    att_encoder = layers.Sequential(pc, [layers.TransformerEncoderLayer(pc, N_INPUT, N_FF) for _ in range(N_LAYER)])
    dropout_md = layers.DropoutLastN(pc)
    #
    rop = layers.RefreshOptions(hdrop=0.2, gdrop=0.2, dropmd=0.2, fix_drop=True)
    rnn_encoder.refresh(rop)
    cnn_encoder.refresh(rop)
    att_encoder.refresh(rop)
    dropout_md.refresh(rop)
    #
    x = BK.input_real(np.random.randn(N_BATCH, N_SEQ, N_INPUT))
    x_mask = np.asarray([[1.]*z + [0.]*(N_SEQ-z) for z in np.random.randint(N_SEQ//2, N_SEQ, N_BATCH)])
    y_rnn = rnn_encoder(x, x_mask)
    y_cnn = cnn_encoder(x, x_mask)
    y_att = att_encoder(x, x_mask)
    zz = dropout_md(y_att)
    print("The end.")
    pass
def _inference_mentions(self, insts: List[Sentence], lexi_repr, enc_repr, mask_expr,
                        extractor: NodeExtractorBase, item_creator):
    sel_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds = \
        extractor.predict(insts, lexi_repr, enc_repr, mask_expr)
    # handling outputs here: prepare new items
    head_idxes_arr = BK.get_value(sel_idxes)  # [*, max-count]
    lab_idxes_arr = BK.get_value(sel_lab_idxes)  # [*, max-count]
    logprobs_arr = BK.get_value(sel_logprobs)  # [*, max-count]
    valid_arr = BK.get_value(sel_valid_mask)  # [*, max-count]
    all_items = []
    bsize, mc = valid_arr.shape
    for one_idxes, one_valids, one_lab_idxes, one_logprobs, one_sent in \
            zip(head_idxes_arr, valid_arr, lab_idxes_arr, logprobs_arr, insts):
        sid = one_sent.sid
        partial_id0 = f"{one_sent.doc.doc_id}-s{one_sent.sid}-i"
        for this_i in range(mc):
            this_valid = float(one_valids[this_i])
            if this_valid == 0:
                # must be compact
                assert np.all(one_valids[this_i:] == 0.)
                all_items.extend([None] * (mc - this_i))
                break
            # todo(note): we need to assign various info at the outside
            this_mention = Mention(HardSpan(sid, int(one_idxes[this_i]), None, None))
            # todo(note): where to filter None?
            this_hlidx = extractor.idx2hlidx(one_lab_idxes[this_i])
            all_items.append(item_creator(partial_id0 + str(this_i), this_mention, this_hlidx,
                                          float(one_logprobs[this_i])))
    # only return the items and the ones useful for later steps: List(sent)[List(items)], *[*, max-count]
    ret_items = np.asarray(all_items, dtype=object).reshape((bsize, mc))
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def __call__(self, input_repr, mask_arr, require_loss, require_pred, gold_pos_arr=None):
    enc0_expr = self.enc(input_repr, mask_arr)  # [*, len, d]
    #
    enc1_expr = enc0_expr
    pos_probs, pos_losses_expr, pos_preds_expr = None, None, None
    if self.jpos_multitask:
        # get probabilities
        pos_logits = self.pred(enc0_expr)  # [*, len, nl]
        pos_probs = BK.softmax(pos_logits, dim=-1)
        # stacking for input -> output
        if self.jpos_stacking:
            enc1_expr = enc0_expr + BK.matmul(pos_probs, self.pos_weights)
        # simple cross entropy loss
        if require_loss and self.jpos_lambda > 0.:
            gold_probs = BK.gather_one_lastdim(pos_probs, gold_pos_arr).squeeze(-1)  # [*, len]
            # todo(warn): multiplying the factor here, but not masking here (masking in the final steps)
            pos_losses_expr = (-self.jpos_lambda) * gold_probs.log()
        # simple argmax for prediction
        if require_pred and self.jpos_decode:
            pos_preds_expr = pos_probs.max(dim=-1)[1]
    return enc1_expr, (pos_probs, pos_losses_expr, pos_preds_expr)
def select_plain(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]:
    flattened_states, cur_arc_scores, scoring_mask_ct = candidates
    cur_cache = self.cache
    cur_bsize = len(flattened_states)
    cur_slen = cur_cache.max_slen
    cur_arc_scores_flattened = cur_arc_scores.view([cur_bsize, -1])  # [bs, Lm*Lh]
    if mode == "topk":
        # arcs [*, k]
        topk_arc_scores, topk_arc_idxes = BK.topk(
            cur_arc_scores_flattened, min(k_arc, BK.get_shape(cur_arc_scores_flattened, -1)), dim=-1, sorted=False)
        topk_m, topk_h = topk_arc_idxes / cur_slen, topk_arc_idxes % cur_slen  # [m, h]
        # labels [*, k, k']
        cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, self.mw_arc, self.mw_label)
        topk_label_scores, topk_label_idxes = BK.topk(
            cur_label_scores, min(k_label, BK.get_shape(cur_label_scores, -1)), dim=-1, sorted=False)
        return self._new_states(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h,
                                topk_label_scores, topk_label_idxes)
    elif mode == "":
        return [[]] * cur_bsize
    # todo(+N): other modes like sampling to be implemented: sample, topk-sample
    else:
        raise NotImplementedError(mode)
def _get_rel_dist(self, len_q: int, len_k: int = None):
    if len_k is None:
        len_k = len_q
    dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
    dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
    distance = dist_x - dist_y  # [len_q, len_k]
    return distance
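# A minimal standalone illustration (not from this codebase) of the relative-distance matrix
# built above: entry [q, k] is simply key_index - query_index. In plain PyTorch, for len_q = len_k = 3:
import torch

rel_dist = torch.arange(3).unsqueeze(0) - torch.arange(3).unsqueeze(1)
# tensor([[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]])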
def _score_label_full(self, scoring_expr_pack, mask_expr, training, margin, gold_heads_expr=None, gold_labels_expr=None):
    _, _, lm_expr, lh_expr = scoring_expr_pack
    full_label_score = self.scorer.score_label_all(lm_expr, lh_expr, mask_expr, mask_expr)  # [BS, len-m, len-h, L]
    # # set diag to small values  # todo(warn): handled specifically in algorithms
    # maxlen = BK.get_shape(full_label_score, 1)
    # full_label_score += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
    # margin? -- special reshaping
    if training and margin > 0.:
        full_shape = BK.get_shape(full_label_score)
        # combine the last two dims
        combined_score_expr = full_label_score.view(full_shape[:-2] + [-1])
        combined_idx_expr = gold_heads_expr * full_shape[-1] + gold_labels_expr
        combined_changed_score = BK.minus_margin(combined_score_expr, combined_idx_expr, margin)
        full_label_score = combined_changed_score.view(full_shape)
    return full_label_score
def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
    conf = self.conf
    # score
    scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
    # get gold
    gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.long)  # [bs, ?+rlen]
    for bidx, inst in enumerate(insts):
        cur_seq_idxes = getattr(inst, self.attr_name).idxes
        if self.add_root_token:
            gold_pidxes[bidx, 1:1 + len(cur_seq_idxes)] = cur_seq_idxes
        else:
            gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
    # get loss
    margin = self.margin.value
    gold_pidxes_t = BK.input_idx(gold_pidxes)
    gold_pidxes_t *= (gold_pidxes_t < self.pred_out_dim).long()  # 0 means invalid ones!!
    loss_mask_t = (gold_pidxes_t > 0).float() * mask_t  # [bs, ?+rlen]
    lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t, margin=margin)  # [bs, ?+rlen]
    # argmax
    _, argmax_idxes = scores_t.max(-1)
    pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
    # compile loss
    lab_loss = LossHelper.compile_leaf_info("slab", lab_losses_t.sum(), loss_mask_t.sum(), corr=pred_corrs.sum())
    return self._compile_component_loss(self.pname, [lab_loss])
def get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr,
                            mask_expr, clamping=True):
    # combine the last two dimensions
    full_shape = BK.get_shape(full_score_expr)
    last_size = full_shape[-1]
    combined_score_expr = full_score_expr.view(full_shape[:-2] + [-1])  # [*, m, h*L]
    gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr  # [*, m]
    pred_combined_idx_expr = pred_heads_expr * last_size + pred_labels_expr  # [*, m]
    gold_scores = BK.gather_one_lastdim(combined_score_expr, gold_combined_idx_expr).squeeze(-1)
    pred_scores = BK.gather_one_lastdim(combined_score_expr, pred_combined_idx_expr).squeeze(-1)
    # todo(warn): be aware of search error!
    # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.)  # this is the previous version
    hinge_losses = pred_scores - gold_scores  # [*, len]
    if clamping:
        valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1)  # [*, 1]
        return hinge_losses * valid_losses
    else:
        # for this mode, will there be problems of search error? Maybe rare.
        return hinge_losses
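# A minimal standalone sketch (assumes torch; not from this codebase) of the index arithmetic
# used above: with L labels, the pair (head=h, label=l) maps to the flat index h * L + l in the
# [*, m, h*L] view, so a single gather over the last dimension recovers score[m, h, l].
import torch

m, h, L = 2, 3, 4
full_score = torch.randn(m, h, L)
gold_heads, gold_labels = torch.tensor([2, 0]), torch.tensor([1, 3])
flat_idx = gold_heads * L + gold_labels  # [m]
gold_scores = full_score.view(m, -1).gather(-1, flat_idx.unsqueeze(-1)).squeeze(-1)  # [m]
assert torch.allclose(gold_scores, full_score[torch.arange(m), gold_heads, gold_labels])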
def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf, vpack: VocabPackage):
    super().__init__(pc, None, None)
    self.conf = conf
    # vocab and padder
    self.word_vocab = vpack.get_voc("word")
    self.padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)  # todo(note): <pad>-id is very large
    # models
    self.hid_layer = self.add_sub_node("hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
    self.pred_layer = self.add_sub_node("pred", Affine(pc, conf.hid_dim, conf.max_pred_rank + 1, init_rop=NoDropRop()))
    if conf.init_pred_from_pretrain:
        npvec = vpack.get_emb("word")
        if npvec is None:
            zwarn("Pretrained vector not provided, skip init pred embeddings!!")
        else:
            with BK.no_grad_env():
                self.pred_layer.ws[0].copy_(BK.input_real(npvec[:conf.max_pred_rank + 1].T))
            zlog(f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1}).")
def _losses_global_prob(self, full_score_expr, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr):
    # combine the last two dimensions
    full_shape = BK.get_shape(full_score_expr)
    last_size = full_shape[-1]
    combined_marginals_expr = marginals_expr.view(full_shape[:-2] + [-1])  # [*, m, h*L]
    # # todo(warn): make sure they sum to 1., handled in the algorithm instead
    # combined_marginals_expr = combined_marginals_expr / combined_marginals_expr.sum(dim=-1, keepdim=True)
    gold_combined_idx_expr = gold_heads_expr * last_size + gold_labels_expr  # [*, m]
    gradients = BK.minus_margin(combined_marginals_expr, gold_combined_idx_expr, 1.).view(full_shape)  # [*, m, h, L]
    # the gradients on h are already 0. from the marginal algorithm
    gradients_masked = gradients * mask_expr.unsqueeze(-1).unsqueeze(-1) * mask_expr.unsqueeze(-2).unsqueeze(-1)
    # for the h-dimension, need to divide by the real length.
    # todo(warn): these values should be directly summed rather than averaged, since they come directly from the loss
    fake_losses = (full_score_expr * gradients_masked).sum(-1).sum(-1)  # [BS, m]
    # todo(warn): be aware of search-error-like output constraints;
    # but this clamp for all is not good for loss-prob, dealt with outside via the unproj-mask.
    # <bad> fake_losses = BK.clamp(fake_losses, min=0.)
    return fake_losses
def _score(self, repr_t, attn_t, mask_t):
    conf = self.conf
    # -----
    repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
    repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
    scores0 = self.dps_node.paired_score(repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
    # mask at outside
    slen = BK.get_shape(mask_t, -1)
    score_mask = BK.constants(BK.get_shape(scores0)[:-1], 1.)  # [bs, len_q, len_k]
    score_mask *= (1. - BK.eye(slen))  # no diag
    score_mask *= mask_t.unsqueeze(-1)  # input mask at len_k
    score_mask *= mask_t.unsqueeze(-2)  # input mask at len_q
    NEG = Constants.REAL_PRAC_MIN
    scores1 = scores0 + NEG * (1. - score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
    # add fixed idx0 scores if set
    if conf.fix_s0:
        fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
        scores1 = (1. - fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
    # minus s0
    if conf.minus_s0:
        scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
    return scores1, score_mask
def forward_features(self, ids_expr, mask_expr, typeids_expr, other_embed_exprs: List):
    bmodel = self.model
    bmodel_embedding = bmodel.embeddings
    bmodel_encoder = bmodel.encoder
    # prepare
    attention_mask = mask_expr
    token_type_ids = BK.zeros(BK.get_shape(ids_expr)).long() if typeids_expr is None else typeids_expr
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    # extended_attention_mask = extended_attention_mask.to(dtype=next(bmodel.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    # embeddings
    cur_layer = 0
    if self.trainable_min_layer <= 0:
        last_output = bmodel_embedding(ids_expr, position_ids=None, token_type_ids=token_type_ids)
    else:
        with BK.no_grad_env():
            last_output = bmodel_embedding(ids_expr, position_ids=None, token_type_ids=token_type_ids)
    # extra embeddings (this implies overall gradient requirements!!)
    for one_eidx, one_embed in enumerate(self.other_embeds):
        last_output += one_embed(other_embed_exprs[one_eidx])  # [bs, slen, D]
    # =====
    all_outputs = []
    if self.layer_is_output[cur_layer]:
        all_outputs.append(last_output)
    cur_layer += 1
    # todo(note): be careful about the indexes!
    # not-trainable encoders
    trainable_min_layer_idx = max(0, self.trainable_min_layer - 1)
    with BK.no_grad_env():
        for layer_module in bmodel_encoder.layer[:trainable_min_layer_idx]:
            last_output = layer_module(last_output, extended_attention_mask, None)[0]
            if self.layer_is_output[cur_layer]:
                all_outputs.append(last_output)
            cur_layer += 1
    # trainable encoders
    for layer_module in bmodel_encoder.layer[trainable_min_layer_idx:self.output_max_layer]:
        last_output = layer_module(last_output, extended_attention_mask, None)[0]
        if self.layer_is_output[cur_layer]:
            all_outputs.append(last_output)
        cur_layer += 1
    assert cur_layer == self.output_max_layer + 1
    # stack
    if len(all_outputs) == 1:
        ret_expr = all_outputs[0].unsqueeze(-2)
    else:
        ret_expr = BK.stack(all_outputs, -2)  # [BS, SLEN, LAYER, D]
    final_ret_exp = self.output_f(ret_expr)
    return final_ret_exp
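# A minimal standalone sketch (plain PyTorch; the helper name is hypothetical, not from this
# codebase) of the frozen-vs-trainable layer split above: the lower layers run under no_grad,
# so no gradients flow into them, while the upper layers keep their gradient history.
import torch

def run_split_layers(x: torch.Tensor, layers: list, n_frozen: int) -> torch.Tensor:
    with torch.no_grad():
        for layer in layers[:n_frozen]:  # frozen lower layers: features only
            x = layer(x)
    for layer in layers[n_frozen:]:  # trainable upper layers
        x = layer(x)
    return x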