def predict(self, repr_ef, repr_evt, lab_ef, lab_evt, mask_ef=None, mask_evt=None, ret_full_logprobs=False):
    # -----
    ret_shape = BK.get_shape(lab_ef)[:-1] + [BK.get_shape(lab_ef, -1), BK.get_shape(lab_evt, -1)]
    if np.prod(ret_shape) == 0:
        if ret_full_logprobs:
            return BK.zeros(ret_shape + [self.num_label])
        else:
            return BK.zeros(ret_shape), BK.zeros(ret_shape).long()
    # -----
    # todo(note): +1 for space of DROPED(UNK)
    full_score = self._score(repr_ef, repr_evt, lab_ef + 1, lab_evt + 1)  # [*, len-ef, len-evt, D]
    full_logprobs = BK.log_softmax(full_score, -1)
    if ret_full_logprobs:
        return full_logprobs
    else:
        # greedy maximum decode
        ret_logprobs, ret_idxes = full_logprobs.max(-1)  # [*, len-ef, len-evt]
        # mask non-valid ones
        if mask_ef is not None:
            ret_idxes *= (mask_ef.unsqueeze(-1)).long()
        if mask_evt is not None:
            ret_idxes *= (mask_evt.unsqueeze(-2)).long()
        return ret_logprobs, ret_idxes
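# [Illustrative sketch, not part of the original module; plain PyTorch, assuming BK follows torch
# semantics] predict() above takes the argmax over the label dimension of the log-probabilities and
# then zeroes out the label ids of pairs whose ef position is masked, so masked pairs fall back to
# label 0 (the NIL role). All names below are hypothetical.
def _demo_greedy_pair_decode():
    import torch
    logprobs = torch.log_softmax(torch.randn(1, 2, 3, 4), -1)   # [bs, len_ef, len_evt, num_label]
    best_lp, best_idx = logprobs.max(-1)                         # [bs, len_ef, len_evt]
    mask_ef = torch.tensor([[1., 0.]])                           # second ef slot is padding
    best_idx = best_idx * mask_ef.unsqueeze(-1).long()           # its row collapses to label 0 (NIL)
    assert (best_idx[:, 1] == 0).all()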
def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t):
    ret_t = cur_t  # [*, D]
    # padding 0 if not using labels
    dim_label = self.dim_label
    # child features
    if self.use_chs and chs_t is not None:
        if self.use_label_feat:
            chs_label_rt = self.label_embeddings(chs_label_t)  # [*, max-chs, dlab]
        else:
            labels_shape = BK.get_shape(chs_t)
            labels_shape[-1] = dim_label
            chs_label_rt = BK.zeros(labels_shape)
        chs_input_t = BK.concat([chs_t, chs_label_rt], -1)
        chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t, chs_valid_mask_t)
        chs_feat = self.chs_ff(chs_feat0)
        ret_t += chs_feat
    # parent features
    if self.use_par and par_t is not None:
        if self.use_label_feat:
            cur_label_t = self.label_embeddings(label_t)  # [*, dlab]
        else:
            labels_shape = BK.get_shape(par_t)
            labels_shape[-1] = dim_label
            cur_label_t = BK.zeros(labels_shape)
        par_feat = self.par_ff([par_t, cur_label_t])
        if par_mask_t is not None:
            par_feat *= par_mask_t.unsqueeze(-1)
        ret_t += par_feat
    return ret_t
def _fb_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
             evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin):
    # get the gold idxes
    arg_linker = self.arg_linker
    bsize, len_ef = ef_items.shape
    bsize2, len_evt = evt_items.shape
    assert bsize == bsize2
    gold_idxes = np.zeros([bsize, len_ef, len_evt], dtype=np.long)
    for one_gold_idxes, one_ef_items, one_evt_items in zip(gold_idxes, ef_items, evt_items):
        # todo(note): check each pair
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            role_map = {id(z.evt): z.role_idx for z in one_ef.links}  # todo(note): since we get the original linked ones
            for evt_idx, one_evt in enumerate(one_evt_items):
                pairwise_role_hlidx = role_map.get(id(one_evt))
                if pairwise_role_hlidx is not None:
                    pairwise_role_idx = arg_linker.hlidx2idx(pairwise_role_hlidx)
                    assert pairwise_role_idx > 0
                    one_gold_idxes[ef_idx, evt_idx] = pairwise_role_idx
    # get loss
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    if np.prod(gold_idxes.shape) == 0:
        # no instances!
        return [[BK.zeros([]), BK.zeros([])]]
    else:
        gold_idxes_t = BK.input_idx(gold_idxes)
        return arg_linker.loss(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes, ef_valid_mask, evt_valid_mask, gold_idxes_t, margin)
def init_call(self, src):
    # init accumulated attn: all 0.
    src_shape = BK.get_shape(src)
    attn_shape = src_shape[:-1] + [src_shape[-2], self.attn_count]  # [*, len_q, len_k, head]
    cache = VRecCache()
    cache.orig_t = src
    cache.rec_t = src  # initially, the same as "orig_t"
    cache.rec_lstm_c_t = BK.zeros(BK.get_shape(src))  # initially zero
    cache.accu_attn = BK.zeros(attn_shape)
    return cache
def loss(self, ms_items: List, bert_expr, basic_expr):
    conf = self.conf
    bsize = len(ms_items)
    # use gold targets: only use positive samples!!
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.events, True, False, 0., 0., True)  # [bs, ?]
    realis_flist = [(-1 if (z is None or z.realis_idx is None) else z.realis_idx) for z in items_arr.flatten()]
    realis_t = BK.input_idx(realis_flist).view(items_arr.shape)  # [bs, ?]
    realis_mask = (realis_t >= 0).float()
    realis_t.clamp_(min=0)  # make sure all idxes are legal
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]  # realis, types
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build losses
    loss_item_realis = self._get_one_loss(self.realis_predictor, hiddens, realis_t, realis_mask, conf.lambda_realis)
    loss_item_type = self._get_one_loss(self.type_predictor, hiddens, labels_t, masks_t, conf.lambda_type)
    return [loss_item_realis, loss_item_type]
def get_selected_label_scores(self, idxes_m_t, idxes_h_t, bsize_range_t, oracle_mask_t, oracle_label_t,
                              arc_margin: float, label_margin: float):
    # todo(note): in this mode, no repeated arc_margin
    dim1_range_t = bsize_range_t
    dim2_range_t = dim1_range_t.unsqueeze(-1)
    if self.system_labeled:
        selected_m_cache = [z[dim2_range_t, idxes_m_t] for z in self.mod_label_cache]
        selected_h_repr = self.head_label_cache[dim2_range_t, idxes_h_t]
        ret = self.scorer.score_label(selected_m_cache, selected_h_repr)  # [*, k, labels]
        if label_margin > 0.:
            oracle_label_idxes = oracle_label_t[dim2_range_t, idxes_m_t, idxes_h_t].unsqueeze(-1)  # [*, k, 1] of int
            ret.scatter_add_(-1, oracle_label_idxes, BK.constants(oracle_label_idxes.shape, -label_margin))
    else:
        # todo(note): otherwise, simply put zeros (with idx=0 as the slightly best to be consistent)
        ret = BK.zeros(BK.get_shape(idxes_m_t) + [self.num_label])
        ret[:, :, 0] += 0.01
    if self.g1_lab_scores is not None:
        ret += self.g1_lab_scores[dim2_range_t, idxes_m_t, idxes_h_t]
    return ret
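# [Illustrative sketch, not part of the original module] get_selected_label_scores applies a
# cost-augmented margin by subtracting label_margin at the oracle label index of every (mod, head)
# pair via scatter_add_. Below is an assumed plain-PyTorch version of just that update, with
# hypothetical argument names, shown to make the indexing concrete.
def _demo_label_margin(scores, oracle_idxes, label_margin):
    # scores: [*, k, num_label] float; oracle_idxes: [*, k, 1] long
    import torch
    scores.scatter_add_(-1, oracle_idxes, torch.full(oracle_idxes.shape, -label_margin, dtype=scores.dtype))
    return scores  # gold labels are now penalized by label_margin relative to the rest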
def loss(self, ms_items: List, bert_expr):
    conf = self.conf
    max_range = self.conf.max_range
    bsize = len(ms_items)
    # collect instances
    col_efs, _, col_bidxes_t, col_hidxes_t, col_ldists_t, col_rdists_t = self._collect_insts(ms_items, True)
    if len(col_efs) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz], [zzz, zzz, zzz]]
    left_scores, right_scores = self._score(bert_expr, col_bidxes_t, col_hidxes_t)  # [N, R]
    if conf.use_binary_scorer:
        left_binaries, right_binaries = (BK.arange_idx(max_range) <= col_ldists_t.unsqueeze(-1)).float(), \
                                        (BK.arange_idx(max_range) <= col_rdists_t.unsqueeze(-1)).float()  # [N,R]
        left_losses = BK.binary_cross_entropy_with_logits(left_scores, left_binaries, reduction='none')[:, 1:]
        right_losses = BK.binary_cross_entropy_with_logits(right_scores, right_binaries, reduction='none')[:, 1:]
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0) * (max_range - 1))
    else:
        left_losses = BK.loss_nll(left_scores, col_ldists_t)
        right_losses = BK.loss_nll(right_scores, col_rdists_t)
        left_count = right_count = BK.input_real(BK.get_shape(left_losses, 0))
    return [[left_losses.sum(), left_count, left_count], [right_losses.sum(), right_count, right_count]]
def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float],
          label_weight_list: List[float], bidxes_list: List[int]):
    # 1. collect (batched) features; todo(note): use prev state for scoring
    hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list])
    # 2. get new sreprs
    scorer = self.scorer
    s_enc = self.slayer
    bsize_range_t = BK.input_idx(bidxes_list)
    node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t)
    node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t)
    # label loss
    if self.system_labeled:
        node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False)
        _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True)
        label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr)  # [*, Lab]
        label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1)
        final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum()
    else:
        label_scores = final_label_loss_sum = BK.zeros([])
    # arc loss
    node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False)
    _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True)
    arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1)
    final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum()
    # score reg
    return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores
def loss(self, ms_items: List, bert_expr, basic_expr, margin=0.):
    conf = self.conf
    bsize = len(ms_items)
    # build targets (include all sents)
    # todo(note): use "x.entity_fillers" for getting gold args
    offsets_t, masks_t, _, items_arr, labels_t = PrepHelper.prep_targets(
        ms_items, lambda x: x.entity_fillers, True, True, conf.train_neg_rate, conf.train_neg_rate_outside, True)
    labels_t.clamp_(max=1)  # either 0 or 1
    # -----
    # return 0 if all no targets
    if BK.get_shape(offsets_t, -1) == 0:
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    # -----
    arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
    sel_bert_t = bert_expr[arange_t, offsets_t]  # [bsize, ?, Fold, D]
    sel_basic_t = None if basic_expr is None else basic_expr[arange_t, offsets_t]  # [bsize, ?, D']
    hiddens = self.adp(sel_bert_t, sel_basic_t, [])  # [bsize, ?, D"]
    # build loss
    logits = self.predictor(hiddens)  # [bsize, ?, Out]
    log_probs = BK.log_softmax(logits, -1)
    picked_log_probs = -BK.gather_one_lastdim(log_probs, labels_t).squeeze(-1)  # [bsize, ?]
    masked_losses = picked_log_probs * masks_t
    # loss_sum, loss_count, gold_count
    return [[masked_losses.sum(), masks_t.sum(), (labels_t > 0).float().sum()]]
def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
    mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
    if BK.get_shape(mask_idxes, -1) == 0:
        # no loss
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    else:
        target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
        target_hids = self.hid_layer(target_reprs)
        target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
        pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
        target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
        target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones in range
        # get loss
        pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
        pred_loss_sum = (pred_losses * mask_valids).sum()
        pred_loss_count = mask_valids.sum()
        # argmax
        _, argmax_idxes = target_scores.max(-1)
        pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
        pred_corr_count = pred_corrs.sum()
        return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
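# [Illustrative sketch, not part of the original module] The masked-LM loss above (and several
# other functions in this section) relies on BK.mask2idx to turn a 0/1 mask into padded index
# lists plus a validity mask. The helper below is an assumed plain-PyTorch equivalent (the real
# BK implementation may differ in padding details), shown only to make the expected shapes concrete.
def _demo_mask2idx(mask):  # mask: [bsize, slen] of 0./1.
    import torch
    counts = mask.sum(-1).long()                                     # [bsize], hits per row
    max_count = int(counts.max().item()) if counts.numel() > 0 else 0
    idxes = torch.zeros(mask.shape[0], max_count, dtype=torch.long)  # padded with index 0
    valids = torch.zeros(mask.shape[0], max_count)                   # 1. where the index is real
    for b in range(mask.shape[0]):
        hits = mask[b].nonzero(as_tuple=True)[0]                     # positions where mask==1
        idxes[b, :len(hits)] = hits
        valids[b, :len(hits)] = 1.
    return idxes, valids  # [bsize, max_count] each, matching how mask2idx outputs are used above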
def loss(self, repr_ef, repr_evt, lab_ef, lab_evt, mask_ef, mask_evt, gold_idxes, margin=0.):
    conf = self.conf
    # -----
    if np.prod(BK.get_shape(gold_idxes)) == 0:
        return [[BK.zeros([]), BK.zeros([])]]
    # -----
    # todo(note): +1 for space of DROPED(UNK)
    lab_ef = self._dropout_idxes(lab_ef + 1, conf.train_drop_ef_lab)
    lab_evt = self._dropout_idxes(lab_evt + 1, conf.train_drop_evt_lab)
    if conf.linker_ef_detach:
        repr_ef = repr_ef.detach()
    if conf.linker_evt_detach:
        repr_evt = repr_evt.detach()
    full_score = self._score(repr_ef, repr_evt, lab_ef, lab_evt)  # [*, len-ef, len-evt, D]
    if margin > 0.:
        aug_score = BK.zeros(BK.get_shape(full_score)) + margin
        aug_score.scatter_(-1, gold_idxes.unsqueeze(-1), 0.)
        full_score += aug_score
    full_logprobs = BK.log_softmax(full_score, -1)
    gold_logprobs = full_logprobs.gather(-1, gold_idxes.unsqueeze(-1)).squeeze(-1)  # [*, len-ef, len-evt]
    # sampling and mask
    loss_mask = mask_ef.unsqueeze(-1) * mask_evt.unsqueeze(-2)
    # =====
    # first select examples (randomly)
    sel_mask = (BK.rand(BK.get_shape(loss_mask)) < conf.train_min_rate).float()  # [*, len-ef, len-evt]
    # add gold and exclude pad
    sel_mask += (gold_idxes > 0).float()
    sel_mask.clamp_(max=1.)
    loss_mask *= sel_mask
    # =====
    loss_sum = -(gold_logprobs * loss_mask).sum()
    loss_count = loss_mask.sum()
    ret_losses = [[loss_sum, loss_count]]
    return ret_losses
def forward_features(self, ids_expr, mask_expr, typeids_expr, other_embed_exprs: List):
    bmodel = self.model
    bmodel_embedding = bmodel.embeddings
    bmodel_encoder = bmodel.encoder
    # prepare
    attention_mask = mask_expr
    token_type_ids = BK.zeros(BK.get_shape(ids_expr)).long() if typeids_expr is None else typeids_expr
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    # extended_attention_mask = extended_attention_mask.to(dtype=next(bmodel.parameters()).dtype)  # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    # embeddings
    cur_layer = 0
    if self.trainable_min_layer <= 0:
        last_output = bmodel_embedding(ids_expr, position_ids=None, token_type_ids=token_type_ids)
    else:
        with BK.no_grad_env():
            last_output = bmodel_embedding(ids_expr, position_ids=None, token_type_ids=token_type_ids)
    # extra embeddings (this implies overall gradient requirements!!)
    for one_eidx, one_embed in enumerate(self.other_embeds):
        last_output += one_embed(other_embed_exprs[one_eidx])  # [bs, slen, D]
    # =====
    all_outputs = []
    if self.layer_is_output[cur_layer]:
        all_outputs.append(last_output)
    cur_layer += 1
    # todo(note): be careful about the indexes!
    # not-trainable encoders
    trainable_min_layer_idx = max(0, self.trainable_min_layer - 1)
    with BK.no_grad_env():
        for layer_module in bmodel_encoder.layer[:trainable_min_layer_idx]:
            last_output = layer_module(last_output, extended_attention_mask, None)[0]
            if self.layer_is_output[cur_layer]:
                all_outputs.append(last_output)
            cur_layer += 1
    # trainable encoders
    for layer_module in bmodel_encoder.layer[trainable_min_layer_idx:self.output_max_layer]:
        last_output = layer_module(last_output, extended_attention_mask, None)[0]
        if self.layer_is_output[cur_layer]:
            all_outputs.append(last_output)
        cur_layer += 1
    assert cur_layer == self.output_max_layer + 1
    # stack
    if len(all_outputs) == 1:
        ret_expr = all_outputs[0].unsqueeze(-2)
    else:
        ret_expr = BK.stack(all_outputs, -2)  # [BS, SLEN, LAYER, D]
    final_ret_exp = self.output_f(ret_expr)
    return final_ret_exp
def _hl_lookup(self, sel_lab_idxes):
    # the embeddings
    sel_shape = BK.get_shape(sel_lab_idxes)
    if sel_shape[-1] == 0:
        sel_lab_embeds = BK.zeros(sel_shape + [self.conf.lab_conf.n_dim])
    else:
        assert not self.hl.conf.use_lookup_soft, "Cannot do soft-lookup in this mode"
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
    return sel_lab_embeds
def _score_sentence(self, scores, mask, tags):
    """
    input:
        scores: variable (seq_len, batch, tag_size, tag_size)
        mask: (batch, seq_len)
        tags: tensor (batch, seq_len)
    output:
        score: sum of score for gold sequences within whole batch
    """
    # Gives the score of a provided tag sequence
    batch_size = scores.size(1)
    seq_len = scores.size(0)
    tag_size = scores.size(2)
    ## convert tag value into a new format, recorded label bigram information to index
    # new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
    # if self.gpu:
    #     new_tags = new_tags.cuda()
    new_tags = BK.zeros((batch_size, seq_len)).long()
    for idx in range(seq_len):
        if idx == 0:
            ## start -> first score
            new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
        else:
            new_tags[:, idx] = tags[:, idx - 1] * tag_size + tags[:, idx]
    ## transition for label to STOP_TAG
    end_transition = self.transitions[:, STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
    ## length for batch, last word position = length - 1
    length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
    ## index the label id of last word
    end_ids = torch.gather(tags, 1, length_mask - 1)
    ## index the transition score for end_id to STOP_TAG
    end_energy = torch.gather(end_transition, 1, end_ids)
    ## convert tag as (seq_len, batch_size, 1)
    new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
    ### need convert tags id to search from 400 positions of scores
    tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
    ## mask transpose to (seq_len, batch_size)
    tg_energy = tg_energy.masked_select(mask.transpose(1, 0))
    # ## calculate the score from START_TAG to first label
    # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
    # start_energy = torch.gather(start_transition, 1, tags[0,:])
    ## add all score together
    # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
    gold_score = tg_energy.sum() + end_energy.sum()
    return gold_score
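# [Illustrative sketch, not part of the original module] _score_sentence flattens each
# (prev_tag, cur_tag) bigram into a single index prev*tag_size + cur so that one gather over the
# last dimension of the (seq_len, batch, tag_size*tag_size) score tensor picks out the gold
# transition+emission score. A tiny check of that index arithmetic (names are hypothetical):
def _demo_bigram_index():
    import torch
    tag_size = 5
    scores = torch.arange(tag_size * tag_size, dtype=torch.float).view(tag_size, tag_size)  # [from, to]
    prev_tag, cur_tag = 3, 2
    flat_idx = prev_tag * tag_size + cur_tag
    assert scores.view(-1)[flat_idx] == scores[prev_tag, cur_tag]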
def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
    conf = self.conf
    self.refresh_batch(False)
    # print(f"{len(insts)}: {insts[0].sid}")
    with BK.no_grad_env():
        # decode for dpar
        input_map = self.model.inputter(insts)
        emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False)
        input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
        self.dpar.predict(insts, BK.zeros([1, 1]), input_t, mask_t)
    return {}
def lookup(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize = len(insts)
    # first get gold/input info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
    # step 1: no selection, simply forward using gold_masks
    sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1==0
    else:
        # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)  # todo(note): here no softlookup?
        ret_items = sel_items
    # second type
    if self.use_secondary_type:
        sel2_lab_idxes = sel_gold_idxes2
        sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)  # todo(note): here no softlookup?
        sel2_valid_mask = (sel2_lab_idxes > 0).float()
        # combine the two
        if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
            ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
            sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
            sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
            sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
            sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
    # step 3: exclude nil assuming no deliberate nil in gold/inputs
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
    # step 4: return
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # mask out invalid items with None
    ret_items[BK.get_value(1. - sel_valid_mask).astype(np.bool)] = None
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def _special_score(one_score):
    # specially change ablpair scores into [bs,m,h,*]
    root_score = one_score[:, :, 0].unsqueeze(2)  # [bs, rlen, 1, *]
    tmp_shape = BK.get_shape(root_score)
    tmp_shape[1] = 1  # [bs, 1, 1, *]
    padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score], dim=1)  # [bs, rlen+1, 1, *]
    final_score = BK.concat([padded_root_score, one_score.transpose(1, 2)], dim=2)  # [bs, rlen+1[m], rlen+1[h], *]
    return final_score
def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1., rand_gen=None, assign_attns=False, **kwargs):
    self.refresh_batch(training)
    # get inputs with models
    with BK.no_grad_env():
        input_map = self.model.inputter(insts)
        emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True)
        input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
    losses = [self.dpar.loss(insts, BK.zeros([1, 1]), input_t, mask_t)]
    # -----
    info = self.collect_loss_and_backward(losses, training, loss_factor)
    info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
    return info
def _pmask2idxes(self, pred_mask):
    orig_shape = BK.get_shape(pred_mask)
    dim_type = orig_shape[-1]
    flattened_mask = pred_mask.view(orig_shape[:-2] + [-1])  # [*, slen*L]
    f_idxes, sel_valid_mask = BK.mask2idx(flattened_mask)  # [*, max-count]
    # then back to the two dimensions
    sel_idxes, sel_lab_idxes = f_idxes // dim_type, f_idxes % dim_type
    # the embeddings
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_embeds = BK.zeros(sel_shape + [self.conf.lab_conf.n_dim])
    else:
        assert not self.hl.conf.use_lookup_soft, "Cannot do soft-lookup in this mode"
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
    return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
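# [Illustrative sketch, not part of the original module] _pmask2idxes flattens a [*, slen, L] mask
# into [*, slen*L] before mask2idx, so a flat index f decomposes back into (position, label) as
# (f // L, f % L). A quick self-contained check of that arithmetic:
def _demo_flat_index_split():
    slen, L = 7, 3
    for pos in range(slen):
        for lab in range(L):
            f = pos * L + lab               # index in the flattened last dimension
            assert (f // L, f % L) == (pos, lab)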
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    # todo(+N): currently margin is not used
    conf = self.conf
    bsize = len(insts)
    arange_t = BK.arange_idx(bsize)
    assert conf.train_force, "currently only have forced training"
    # get the gold ones
    gold_widxes, gold_lidxes, gold_vmasks, ret_items, _ = self.batch_inputs_g1(insts)  # [*, ?]
    # for all the steps
    num_step = BK.get_shape(gold_widxes, -1)
    # recurrent states
    hard_coverage = BK.zeros(BK.get_shape(input_mask))  # [*, slen]
    prev_state = self.rnn_unit.zero_init_hidden(bsize)  # tuple([*, D], )
    all_tok_logprobs, all_lab_logprobs = [], []
    for cstep in range(num_step):
        slice_widx, slice_lidx = gold_widxes[:, cstep], gold_lidxes[:, cstep]
        _, sel_tok_logprobs, _, sel_lab_logprobs, _, next_state = \
            self._step(input_expr, input_mask, hard_coverage, prev_state, slice_widx, slice_lidx, None)
        all_tok_logprobs.append(sel_tok_logprobs)  # add one of [*, 1]
        all_lab_logprobs.append(sel_lab_logprobs)
        hard_coverage = BK.copy(hard_coverage)  # todo(note): cannot modify inplace!
        hard_coverage[arange_t, slice_widx] += 1.
        prev_state = [z.squeeze(-2) for z in next_state]
    # concat all the loss and mask
    # todo(note): no need to use gold_valid since things are told in vmasks
    cat_tok_logprobs = BK.concat(all_tok_logprobs, -1) * gold_vmasks  # [*, steps]
    cat_lab_logprobs = BK.concat(all_lab_logprobs, -1) * gold_vmasks
    loss_sum = -(cat_tok_logprobs.sum() * conf.lambda_att + cat_lab_logprobs.sum() * conf.lambda_lab)
    # todo(+N): here we are dividing lab_logprobs with the all-count, do we need to separate?
    loss_count = gold_vmasks.sum()
    ret_losses = [[loss_sum, loss_count]]
    # =====
    # make eos invalid for return
    ret_valid_mask = gold_vmasks * (gold_widxes > 0).float()
    # embeddings
    sel_lab_embeds = self._hl_lookup(gold_lidxes)
    return ret_losses, ret_items, gold_widxes, ret_valid_mask, gold_lidxes, sel_lab_embeds
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    conf = self.conf
    bsize = len(insts)
    # first get gold info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
    input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
    # step 1: selector
    if conf.use_selector:
        sel_loss, sel_mask = self.sel.loss(input_expr, input_mask, gold_masks, margin=margin)
    else:
        sel_loss, sel_mask = None, self._select_cands_training(input_mask, gold_masks, conf.train_min_rate)
    sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    # if we select nothing
    # ----- debug
    # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
    # -----
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        lab_loss = [[BK.zeros([]), BK.zeros([])]]
        sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1==0
    else:
        sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(sel_hid_exprs, sel_valid_mask, sel_gold_idxes, margin=margin)
        if conf.train_gold_corr:
            sel_lab_idxes = sel_gold_idxes
            if not self.hl.conf.use_lookup_soft:
                sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
        ret_items = sel_items
        # =====
        if self.use_secondary_type:
            sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
            if conf.sectype_noback_enc:
                sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
            else:
                sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
            # =====
            # special for the sectype mask (sample it within the gold ones)
            sel2_valid_mask = self._select_cands_training((sel_gold_idxes > 0).float(), (sel_gold_idxes2 > 0).float(),
                                                          conf.train_min_rate_s2)
            # =====
            sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(sel2_input, sel2_valid_mask, sel_gold_idxes2, margin=margin)
            if conf.train_gold_corr:
                sel2_lab_idxes = sel_gold_idxes2
                if not self.hl.conf.use_lookup_soft:
                    sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
            if conf.sectype_t2ift1:
                sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
        else:
            sel2_lab_loss = None
    # =====
    # step 3: exclude nil and return
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # step 4: finally prepare loss and items
    for one_loss in lab_loss:
        one_loss[0] *= conf.lambda_ne
    ret_losses = lab_loss
    if sel2_lab_loss is not None:
        for one_loss in sel2_lab_loss:
            one_loss[0] *= conf.lambda_ne2
        ret_losses = ret_losses + sel2_lab_loss
    if sel_loss is not None:
        for one_loss in sel_loss:
            one_loss[0] *= conf.lambda_ns
        ret_losses = ret_losses + sel_loss
    # ----- debug
    # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
    # -----
    # mask out invalid items with None
    ret_items[BK.get_value(1. - sel_valid_mask).astype(np.bool)] = None
    return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
    conf = self.conf
    CR, PR = conf.cand_range, conf.pred_range
    # -----
    mask_single = BK.copy(mask_t)
    # no predictions for ARTI_ROOT
    if self.add_root_token:
        mask_single[:, 0] = 0.  # [bs, slen]
    # casting predicting range
    cur_slen = BK.get_shape(mask_single, -1)
    arange_t = BK.arange_idx(cur_slen)  # [slen]
    # [1, len] - [len, 1] = [len, len]
    reldist_t = (arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1))  # [slen, slen]
    mask_pair = ((reldist_t.abs() <= CR) & (reldist_t != 0)).float()  # within CR-range; [slen, slen]
    mask_pair = mask_pair * mask_single.unsqueeze(-1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
    if disturb_keep_arr is not None:
        mask_pair *= BK.input_real(1. - disturb_keep_arr).unsqueeze(-1)  # no predictions for the kept ones!
    # get all pair scores
    score_t = self.ps_node.paired_score(repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
    # -----
    # loss: normalize on which dim?
    # get the answers first
    if conf.pred_abs:
        answer_t = reldist_t.abs()  # [1,2,3,...,PR]
        answer_t.clamp_(min=0, max=PR - 1)  # [slen, slen], clip in range, distinguish using masks
    else:
        answer_t = BK.where((reldist_t >= 0), reldist_t - 1, reldist_t + 2 * PR)  # [1,2,3,...PR,-PR,-PR+1,...,-1]
        answer_t.clamp_(min=0, max=2 * PR - 1)  # [slen, slen], clip in range, distinguish using masks
    # expand answer into idxes
    answer_hit_t = BK.zeros(BK.get_shape(answer_t) + [2 * PR])  # [len_q, len_k, 2*R]
    answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
    answer_valid_t = ((reldist_t.abs() <= PR) & (reldist_t != 0)).float().unsqueeze(-1)  # [bs, len_q, len_k, 1]
    answer_hit_t = answer_hit_t * mask_pair.unsqueeze(-1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
    # get losses sum(log(answer*prob))
    # -- dim=-1 is standard 2*PR classification, dim=-2 usually have 2*PR candidates, but can be less at edges
    all_losses = []
    for one_dim, one_lambda in zip([-1, -2], [conf.lambda_n1, conf.lambda_n2]):
        if one_lambda > 0.:
            # since currently there can be only one or zero correct answer
            logprob_t = BK.log_softmax(score_t, one_dim)  # [bs, len_q, len_k, 2*R]
            sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)  # [bs, len_q, len_k||2*R]
            cur_dim_mask_t = (answer_hit_t.sum(one_dim) > 0.).float()  # [bs, len_q, len_k||2*R]
            # loss
            cur_dim_loss = -(sumlogprob_t * cur_dim_mask_t).sum()
            cur_dim_count = cur_dim_mask_t.sum()
            # argmax and corr (any correct counts)
            _, cur_argmax_idxes = score_t.max(one_dim)
            cur_corrs = answer_hit_t.gather(one_dim, cur_argmax_idxes.unsqueeze(one_dim))  # [bs, len_q, len_k|1, 2*R|1]
            cur_dim_corr_count = cur_corrs.sum()
            # compile loss
            one_loss = LossHelper.compile_leaf_info(f"d{one_dim}", cur_dim_loss, cur_dim_count,
                                                    loss_lambda=one_lambda, corr=cur_dim_corr_count)
            all_losses.append(one_loss)
    return self._compile_component_loss("orp", all_losses)
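# [Illustrative sketch, not part of the original module] In the non-absolute branch above, a signed
# relative distance d (d != 0) is mapped into a class id in [0, 2*PR-1]: positive d -> d-1, negative
# d -> d+2*PR, and distances outside [-PR, PR] are clamped and later zeroed out by answer_valid_t.
# A small self-contained check of that encoding:
def _demo_reldist_encoding(PR=4):
    for d in range(-PR, PR + 1):
        if d == 0:
            continue
        cls = (d - 1) if d >= 0 else (d + 2 * PR)  # positives fill [0, PR-1], negatives fill [PR, 2*PR-1]
        assert 0 <= cls < 2 * PR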
def fb(self, annotated_insts, scoring_expr_pack, training: bool, loss_factor: float):
    # depth constrain: <= sched_depth
    cur_depth_constrain = int(self.sched_depth.value)
    # run
    ags = [BfsLinearAgenda.init_agenda(TdState, z, self.require_sg) for z in annotated_insts]
    self.oracle_manager.refresh_insts(annotated_insts)
    self.searcher.refresh(scoring_expr_pack)
    self.searcher.go(ags)
    # collect local loss: credit assignment
    if self.train_force or self.train_ss:
        states = []
        for ag in ags:
            for final_state in ag.local_golds:
                # todo(warn): remember to use depth_eff rather than depth
                # todo(warn): deprecated
                # if final_state.depth_eff > cur_depth_constrain:
                #     continue
                states.append(final_state)
        logprobs_arc = [s.arc_score_slice for s in states]
        # no labeling scores for reduce operations
        logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None]
        credits_arc, credits_label = None, None
    elif self.train_of:
        states = []
        for ag in ags:
            for final_state in ag.ends:
                for s in final_state.get_path(True):
                    states.append(s)
        logprobs_arc = [s.arc_score_slice for s in states]
        # no labeling scores for reduce operations
        logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None]
        credits_arc, credits_label = None, None
    elif self.train_rl:
        logprobs_arc, logprobs_label, credits_arc, credits_label = [], [], [], []
        for ag in ags:
            # todo(+2): need to check search failure?
            # todo(+2): ignoring labels when reducing or wrong-arc
            for final_state in ag.ends:
                # todo(warn): deprecated
                # if final_state.depth_eff > cur_depth_constrain:
                #     continue
                one_credits_arc = []
                one_credits_label = []
                self.oracle_manager.set_losses(final_state)
                for s in final_state.get_path(True):
                    _, _, delta_arc, delta_label = s.oracle_loss_cache
                    logprobs_arc.append(s.arc_score_slice)
                    if delta_arc > 0:
                        # only blame arc
                        one_credits_arc.append(-delta_arc)
                    else:
                        one_credits_arc.append(0)
                        if delta_label > 0:
                            logprobs_label.append(s.label_score_slice)
                            one_credits_label.append(-delta_label)
                        elif s.label_score_slice is not None:
                            # not bad labeling
                            logprobs_label.append(s.label_score_slice)
                            one_credits_label.append(0)
                # TODO(+N): minus average may encourage bad moves?
                # balance
                # avg_arc = sum(one_credits_arc) / len(one_credits_arc)
                # avg_label = 0. if len(one_credits_label)==0 else sum(one_credits_label) / len(one_credits_label)
                baseline_arc = baseline_label = -0.5
                credits_arc.extend(z - baseline_arc for z in one_credits_arc)
                credits_label.extend(z - baseline_label for z in one_credits_label)
    else:
        raise NotImplementedError("CANNOT get here!")
    # sum all local losses
    loss_zero = BK.zeros([])
    if len(logprobs_arc) > 0:
        batched_logprobs_arc = SliceManager.combine_slices(logprobs_arc, None)
        loss_arc = (-BK.sum(batched_logprobs_arc)) if (credits_arc is None) \
            else (-BK.sum(batched_logprobs_arc * BK.input_real(credits_arc)))
    else:
        loss_arc = loss_zero
    if len(logprobs_label) > 0:
        batched_logprobs_label = SliceManager.combine_slices(logprobs_label, None)
        loss_label = (-BK.sum(batched_logprobs_label)) if (credits_label is None) \
            else (-BK.sum(batched_logprobs_label * BK.input_real(credits_label)))
    else:
        loss_label = loss_zero
    final_loss_sum = loss_arc + loss_label
    # divide loss by what?
    num_sent = len(annotated_insts)
    num_valid_arcs, num_valid_labels = len(logprobs_arc), len(logprobs_label)
    # num_valid_steps = len(states)
    if self.tconf.loss_div_step:
        final_loss = loss_arc / max(1, num_valid_arcs) + loss_label / max(1, num_valid_labels)
    else:
        final_loss = final_loss_sum / num_sent
    #
    val_loss_arc = BK.get_value(loss_arc).item()
    val_loss_label = BK.get_value(loss_label).item()
    val_loss_sum = val_loss_arc + val_loss_label
    #
    cur_has_loss = 1 if ((num_valid_arcs + num_valid_labels) > 0) else 0
    if training and cur_has_loss:
        BK.backward(final_loss, loss_factor)
    # todo(warn): make tok==steps for dividing in common.run
    info = {"sent": num_sent, "tok": num_valid_arcs, "valid_arc": num_valid_arcs, "valid_label": num_valid_labels,
            "loss_sum": val_loss_sum, "loss_arc": val_loss_arc, "loss_label": val_loss_label,
            "fb_all": 1, "fb_valid": cur_has_loss}
    return info
def loss(self, repr_t, orig_map: Dict, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # --
    # specify input
    add_root_token = self.add_root_token
    # get from inputs
    if isinstance(repr_t, (list, tuple)):
        l2r_repr_t, r2l_repr_t = repr_t
    elif self.split_input_blm:
        l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
    else:
        l2r_repr_t, r2l_repr_t = repr_t, None
    # l2r and r2l
    word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
    slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
    if add_root_token:
        l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:, :-1]], -1)  # pad two extra 0 at front, [bs, 2+rlen-1]
    else:
        l2r_trg_t = BK.concat([word_t[:, 1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
        r2l_trg_t = BK.concat([slice_zero_t, word_t[:, :-1]], -1)  # pad one extra 0 at front, [bs, 1+rlen-1]
    # gather the losses
    all_losses = []
    pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size - 1
    if _tie_input_embeddings:
        pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
    else:
        pred_W = None
    # get input embeddings for output
    for pred_name, hid_node, pred_node, input_t, trg_t in \
            zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
                [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
        if input_t is None:
            continue
        # hidden
        hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
        # pred: [bs, slen, Vsize]
        if _tie_input_embeddings:
            scores_t = BK.matmul(hid_t, pred_W.T)
        else:
            scores_t = pred_node(hid_t)
        # loss
        mask_t = ((trg_t >= pred_range_min) & (trg_t <= pred_range_max)).float()  # [bs, slen]
        trg_t.clamp_(max=pred_range_max)  # make it in range
        losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
        _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
        corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
        # compile leaf loss
        one_loss = LossHelper.compile_leaf_info(pred_name, losses_t.sum(), mask_t.sum(), loss_lambda=1., corr=corrs_t.sum())
        all_losses.append(one_loss)
    return self._compile_component_loss("plm", all_losses)
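# [Illustrative sketch, not part of the original module; plain PyTorch, toy input] The target
# construction above shifts the word ids so that the left-to-right stream predicts the next token
# and the right-to-left stream predicts the previous one, with 0-padding at the boundaries. For the
# add_root_token=False case with a hypothetical word_t = [[11, 22, 33]]:
def _demo_blm_targets():
    import torch
    word_t = torch.tensor([[11, 22, 33]])
    zero = torch.zeros(1, 1, dtype=torch.long)
    l2r_trg_t = torch.cat([word_t[:, 1:], zero], -1)   # drop first token, pad one 0 at the end
    r2l_trg_t = torch.cat([zero, word_t[:, :-1]], -1)  # pad one 0 at the front, drop last token
    assert l2r_trg_t.tolist() == [[22, 33, 0]]
    assert r2l_trg_t.tolist() == [[0, 11, 22]]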
def decode(self, inst: DocInstance):
    conf, model = self.conf, self.model
    model.refresh_batch(False)
    test_constrain_evt_types = self.test_constrain_evt_types
    with BK.no_grad_env():
        # =====
        # init the collections
        flattened_ef_ereprs, flattened_evt_ereprs = [], []
        sent_offsets = [Constants.INT_PRAC_MIN] * len(inst.sents)  # start offset of sent in the flattened erepr
        cur_offset = 0  # current offset
        all_ef_items, all_evt_items = [], []
        # =====
        # first basic run and ef and evt
        all_packs = model.bter.run([inst], training=False)
        for one_pack in all_packs:
            sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
            mask_expr = BK.input_real(mask_arr)
            # =====
            # store the enc reprs and sent offsets
            sent_size, sent_len = BK.get_shape(enc_repr_ef)[:2]
            assert BK.get_shape(enc_repr_ef) == BK.get_shape(enc_repr_evt)
            flattened_ef_ereprs.append(enc_repr_ef.view(sent_size * sent_len, -1))  # [cur_flatten_size, D]
            flattened_evt_ereprs.append(enc_repr_evt.view(sent_size * sent_len, -1))
            for one_sent in sent_insts:
                sent_offsets[one_sent.sid] = cur_offset
                cur_offset += sent_len
            # =====
            lkrc = not conf.dec_debug_mode  # lookup.ret_copy?
            # =====
            # ef
            if conf.lookup_ef:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    model._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, model.ef_extractor, ret_copy=lkrc)
            else:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    model._inference_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, model.ef_extractor, model.ef_creator)
            # collect all valid ones
            all_ef_items.extend(ef_items[BK.get_value(ef_valid_mask).astype(np.bool)])
            # event
            if conf.lookup_evt:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    model._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, model.evt_extractor, ret_copy=lkrc)
            else:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    model._inference_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, model.evt_extractor, model.evt_creator)
            # collect all valid ones
            if test_constrain_evt_types is None:
                all_evt_items.extend(evt_items[BK.get_value(evt_valid_mask).astype(np.bool)])
            else:
                all_evt_items.extend([z for z in evt_items[BK.get_value(evt_valid_mask).astype(np.bool)]
                                      if z.type in test_constrain_evt_types])
        # ====
        # cross-sentence pairwise arg score
        # flattened all enc: [Offset, D]
        flattened_ef_enc_repr, flattened_evt_enc_repr = BK.concat(flattened_ef_ereprs, 0), BK.concat(flattened_evt_ereprs, 0)
        # sort by position in doc
        all_ef_items.sort(key=lambda x: x.mention.hard_span.position(True))
        all_evt_items.sort(key=lambda x: x.mention.hard_span.position(True))
        if not conf.dec_debug_mode:
            # todo(note): delete origin links!
            for z in all_ef_items:
                if z is not None:
                    z.links.clear()
            for z in all_evt_items:
                if z is not None:
                    z.links.clear()
        # get other info
        # todo(note): currently all using head word
        all_ef_offsets = BK.input_idx([sent_offsets[x.mention.hard_span.sid] + x.mention.hard_span.head_wid for x in all_ef_items])
        all_evt_offsets = BK.input_idx([sent_offsets[x.mention.hard_span.sid] + x.mention.hard_span.head_wid for x in all_evt_items])
        all_ef_lab_idxes = BK.input_idx([model.ef_extractor.hlidx2idx(x.type_idx) for x in all_ef_items])
        all_evt_lab_idxes = BK.input_idx([model.evt_extractor.hlidx2idx(x.type_idx) for x in all_evt_items])
        # score all the pairs (with mini-batch)
        mini_batch_size = conf.score_mini_batch
        arg_linker = model.arg_linker
        all_logprobs = BK.zeros([len(all_ef_items), len(all_evt_items), arg_linker.num_label])
        for bidx_ef in range(0, len(all_ef_items), mini_batch_size):
            cur_ef_enc_repr = flattened_ef_enc_repr[all_ef_offsets[bidx_ef:bidx_ef + mini_batch_size]].unsqueeze(0)
            cur_ef_lab_idxes = all_ef_lab_idxes[bidx_ef:bidx_ef + mini_batch_size].unsqueeze(0)
            for bidx_evt in range(0, len(all_evt_items), mini_batch_size):
                cur_evt_enc_repr = flattened_evt_enc_repr[all_evt_offsets[bidx_evt:bidx_evt + mini_batch_size]].unsqueeze(0)
                cur_evt_lab_idxes = all_evt_lab_idxes[bidx_evt:bidx_evt + mini_batch_size].unsqueeze(0)
                all_logprobs[bidx_ef:bidx_ef + mini_batch_size, bidx_evt:bidx_evt + mini_batch_size] = \
                    arg_linker.predict(cur_ef_enc_repr, cur_evt_enc_repr, cur_ef_lab_idxes, cur_evt_lab_idxes,
                                       ret_full_logprobs=True).squeeze(0)
        all_logprobs_arr = BK.get_value(all_logprobs)
        # =====
        # then decode them all using the scores
        self.arg_decode(inst, all_ef_items, all_evt_items, all_logprobs_arr)
        # =====
        # assign and return
        num_pred_arg = 0
        for one_sent in inst.sents:
            one_sent.pred_entity_fillers.clear()
            one_sent.pred_events.clear()
        for z in all_ef_items:
            inst.sents[z.mention.hard_span.sid].pred_entity_fillers.append(z)
        for z in all_evt_items:
            inst.sents[z.mention.hard_span.sid].pred_events.append(z)
            num_pred_arg += len(z.links)
        info = {"doc": 1, "sent": len(inst.sents), "token": sum(s.length - 1 for s in inst.sents),
                "p_ef": len(all_ef_items), "p_evt": len(all_evt_items), "p_arg": num_pred_arg}
        return info
def _viterbi_decode(self, feats, mask):
    """
    input:
        feats: (batch, seq_len, self.tag_size+2)
        mask: (batch, seq_len)
    output:
        decode_idx: (batch, seq_len) decoded sequence
        path_score: (batch, 1) corresponding score for each sequence (to be implemented)
    """
    batch_size = feats.size(0)
    seq_len = feats.size(1)
    tag_size = feats.size(2)
    assert (tag_size == self.tagset_size + 2)
    ## calculate sentence length for each sentence
    length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
    ## mask to (seq_len, batch_size)
    mask = mask.transpose(1, 0).contiguous()
    ins_num = seq_len * batch_size
    ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
    feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
    ## need to consider start
    scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
    scores = scores.view(seq_len, batch_size, tag_size, tag_size)
    # build iter
    seq_iter = enumerate(scores)
    ## record the position of best score
    back_points = list()
    partition_history = list()
    ## reverse mask (bug for mask = 1- mask, use this as alternative choice)
    # mask = 1 + (-1)*mask
    # mask = (1 - mask.long()).byte()
    mask = (1 - mask.long()).bool()
    _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
    # only need start from start_tag
    partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size)  # bat_size * to_target_size
    # print "init part:",partition.size()
    partition_history.append(partition)
    # iter over last scores
    for idx, cur_values in seq_iter:
        # previous to_target is current from_target
        # partition: previous results log(exp(from_target)), #(batch_size * from_target)
        # cur_values: batch_size * from_target * to_target
        cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1)  # do not consider START_TAG/STOP_TAG
        # print "cur value:", cur_values.size()
        partition, cur_bp = torch.max(cur_values, 1)
        # print "partsize:",partition.size()
        # exit(0)
        # print partition
        # print cur_bp
        # print "one best, ",idx
        partition_history.append(partition)
        ## cur_bp: (batch_size, tag_size) max source score position in current tag
        ## set padded label as 0, which will be filtered in post processing
        cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
        back_points.append(cur_bp)
    # exit(0)
    ### add score to final STOP_TAG
    partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()  ## (batch_size, seq_len, tag_size)
    ### get the last position for each sentence, and select the last partitions using gather()
    last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
    last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
    ### calculate the score from last partition to end state (and then select the STOP_TAG from it)
    last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
                  self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
    _, last_bp = torch.max(last_values, 1)
    # pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
    # if self.gpu:
    #     pad_zero = pad_zero.cuda()
    pad_zero = BK.zeros((batch_size, tag_size)).long()
    back_points.append(pad_zero)
    back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
    ## select end ids in STOP_TAG
    pointer = last_bp[:, STOP_TAG]
    insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
    back_points = back_points.transpose(1, 0).contiguous()
    ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
    # print "lp:",last_position
    # print "il:",insert_last
    back_points.scatter_(1, last_position, insert_last)
    # print "bp:",back_points
    # exit(0)
    back_points = back_points.transpose(1, 0).contiguous()
    ## decode from the end, padded position ids are 0, which will be filtered if following evaluation
    # decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
    # if self.gpu:
    #     decode_idx = decode_idx.cuda()
    decode_idx = BK.zeros((seq_len, batch_size)).long()
    decode_idx[-1] = pointer.detach()
    for idx in range(len(back_points) - 2, -1, -1):
        pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
        decode_idx[idx] = pointer.detach().view(batch_size)
    path_score = None
    decode_idx = decode_idx.transpose(1, 0)
    return path_score, decode_idx
def predict(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize, slen = BK.get_shape(input_mask)
    bsize_arange_t_1d = BK.arange_idx(bsize)  # [*]
    bsize_arange_t_2d = bsize_arange_t_1d.unsqueeze(-1)  # [*, 1]
    beam_size = conf.beam_size
    # prepare things with an extra beam dimension
    beam_input_expr, beam_input_mask = input_expr.unsqueeze(-3).expand(-1, beam_size, -1, -1).contiguous(), \
                                       input_mask.unsqueeze(-2).expand(-1, beam_size, -1).contiguous()  # [*, beam, slen, D?]
    # -----
    # recurrent states
    beam_hard_coverage = BK.zeros([bsize, beam_size, slen])  # [*, beam, slen]
    # tuple([*, beam, D], )
    beam_prev_state = [z.unsqueeze(-2).expand(-1, beam_size, -1) for z in self.rnn_unit.zero_init_hidden(bsize)]
    # frozen after reach eos
    beam_noneos = 1. - BK.zeros([bsize, beam_size])  # [*, beam]
    beam_logprobs = BK.zeros([bsize, beam_size])  # [*, beam], sum of logprobs
    beam_logprobs_paths = BK.zeros([bsize, beam_size, 0])  # [*, beam, step]
    beam_tok_paths = BK.zeros([bsize, beam_size, 0]).long()
    beam_lab_paths = BK.zeros([bsize, beam_size, 0]).long()
    # -----
    for cstep in range(conf.max_step):
        # get things of [*, beam, beam]
        sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state = \
            self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state, None, None, beam_size)
        sel_logprobs = sel_tok_logprobs + sel_lab_logprobs  # [*, beam, beam]
        if cstep == 0:
            # special for the first step, only select for the first element
            cur_selections = BK.arange_idx(beam_size).unsqueeze(0).expand(bsize, beam_size)  # [*, beam]
        else:
            # then select the topk in beam*beam (be careful about the frozen ones!!)
            beam_noneos_3d = beam_noneos.unsqueeze(-1)
            # eos can only be followed by eos
            sel_tok_idxes *= beam_noneos_3d.long()
            sel_lab_idxes *= beam_noneos_3d.long()
            # numeric tricks to keep the frozen ones ([0] with 0. score, [1:] with -inf scores)
            sel_logprobs *= beam_noneos_3d
            tmp_exclude_mask = 1. - beam_noneos_3d.expand_as(sel_logprobs)
            tmp_exclude_mask[:, :, 0] = 0.
            sel_logprobs += tmp_exclude_mask * Constants.REAL_PRAC_MIN
            # select for topk
            topk_logprobs = (beam_noneos * beam_logprobs).unsqueeze(-1) + sel_logprobs
            _, cur_selections = topk_logprobs.view([bsize, -1]).topk(beam_size, dim=-1, sorted=True)  # [*, beam]
        # read and write the selections
        # gathering previous ones
        cur_sel_previ = cur_selections // beam_size  # [*, beam]
        prev_hard_coverage = beam_hard_coverage[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_noneos = beam_noneos[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_logprobs = beam_logprobs[bsize_arange_t_2d, cur_sel_previ]  # [*, beam]
        prev_logprobs_paths = beam_logprobs_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        prev_tok_paths = beam_tok_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        prev_lab_paths = beam_lab_paths[bsize_arange_t_2d, cur_sel_previ]  # [*, beam, step]
        # prepare new ones
        cur_sel_newi = cur_selections % beam_size
        new_tok_idxes = sel_tok_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_lab_idxes = sel_lab_idxes[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_logprobs = sel_logprobs[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi]  # [*, beam]
        new_prev_state = [z[bsize_arange_t_2d, cur_sel_previ, cur_sel_newi] for z in next_state]  # [*, beam, ~]
        # update
        prev_hard_coverage[bsize_arange_t_2d, BK.arange_idx(beam_size).unsqueeze(0), new_tok_idxes] += 1.
        beam_hard_coverage = prev_hard_coverage
        beam_prev_state = new_prev_state
        beam_noneos = prev_noneos * (new_tok_idxes != 0).float()
        beam_logprobs = prev_logprobs + new_logprobs
        beam_logprobs_paths = BK.concat([prev_logprobs_paths, new_logprobs.unsqueeze(-1)], -1)
        beam_tok_paths = BK.concat([prev_tok_paths, new_tok_idxes.unsqueeze(-1)], -1)
        beam_lab_paths = BK.concat([prev_lab_paths, new_lab_idxes.unsqueeze(-1)], -1)
    # finally force an extra eos step to get ending tok-logprob (no need to update other things)
    final_eos_idxes = BK.zeros([bsize, beam_size]).long()
    _, eos_logprobs, _, _, _, _ = self._step(beam_input_expr, beam_input_mask, beam_hard_coverage, beam_prev_state,
                                             final_eos_idxes, final_eos_idxes, None)
    beam_logprobs += eos_logprobs.squeeze(-1) * beam_noneos  # [*, beam]
    # select and return the best one
    beam_tok_valids = (beam_tok_paths > 0).float()  # [*, beam, steps]
    final_scores = beam_logprobs / ((beam_tok_valids.sum(-1) + 1.) ** conf.len_alpha)  # [*, beam]
    _, best_beam_idx = final_scores.max(-1)  # [*]
    # -----
    # prepare returns; cut by max length: [*, all_step] -> [*, max_step]
    ret0_valid_mask = beam_tok_valids[bsize_arange_t_1d, best_beam_idx]
    cur_max_step = ret0_valid_mask.long().sum(-1).max().item()
    ret_valid_mask = ret0_valid_mask[:, :cur_max_step]
    ret_logprobs = beam_logprobs_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    ret_tok_idxes = beam_tok_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    ret_lab_idxes = beam_lab_paths[bsize_arange_t_1d, best_beam_idx][:, :cur_max_step]
    # embeddings
    ret_lab_embeds = self._hl_lookup(ret_lab_idxes)
    return ret_logprobs, ret_tok_idxes, ret_valid_mask, ret_lab_idxes, ret_lab_embeds
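# [Illustrative sketch, not part of the original module] The beam search above flattens the
# [beam, beam] expansion scores to [beam*beam] before topk, so each selected flat index s splits
# into (source beam, expansion within that beam) as (s // beam_size, s % beam_size); that is how
# cur_sel_previ and cur_sel_newi are recovered. A small self-contained check:
def _demo_beam_index_split(beam_size=4):
    for prev_i in range(beam_size):
        for new_i in range(beam_size):
            s = prev_i * beam_size + new_i
            assert (s // beam_size, s % beam_size) == (prev_i, new_i)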
def zeros(self, batch): return BK.zeros((batch, self.dim))
def inference_on_batch(self, insts: List[DocInstance], **kwargs):
    self.refresh_batch(False)
    test_constrain_evt_types = self.test_constrain_evt_types
    ndoc, nsent = len(insts), 0
    iconf = self.conf.iconf
    with BK.no_grad_env():
        # splitting into buckets
        all_packs = self.bter.run(insts, training=False)
        for one_pack in all_packs:
            # =====
            # predict
            sent_insts, lexi_repr, enc_repr_ef, enc_repr_evt, mask_arr = one_pack
            nsent += len(sent_insts)
            mask_expr = BK.input_real(mask_arr)
            # entity and filler
            if iconf.lookup_ef:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    self._lookup_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, ret_copy=True)
            elif iconf.pred_ef:
                ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, ef_lab_embeds = \
                    self._inference_mentions(sent_insts, lexi_repr, enc_repr_ef, mask_expr, self.ef_extractor, self.ef_creator)
            else:
                ef_items = [[] for _ in range(len(sent_insts))]
                ef_valid_mask = BK.zeros((len(sent_insts), 0))
                ef_widxes = ef_lab_idxes = ef_lab_embeds = None
            # event
            if iconf.lookup_evt:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    self._lookup_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, ret_copy=True)
            else:
                evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, evt_lab_embeds = \
                    self._inference_mentions(sent_insts, lexi_repr, enc_repr_evt, mask_expr, self.evt_extractor, self.evt_creator)
            # arg
            if iconf.pred_arg:
                # todo(note): for this step of decoding, we only consider inner-sentence pairs
                # todo(note): inplaced
                self._inference_args(ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                                     evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt)
            # =====
            # assign
            for one_sent_inst, one_ef_items, one_ef_valid, one_evt_items, one_evt_valid in \
                    zip(sent_insts, ef_items, BK.get_value(ef_valid_mask), evt_items, BK.get_value(evt_valid_mask)):
                # entity and filler
                one_ef_items = [z for z, va in zip(one_ef_items, one_ef_valid) if (va and z is not None)]
                one_sent_inst.pred_entity_fillers = one_ef_items
                # event
                one_evt_items = [z for z, va in zip(one_evt_items, one_evt_valid) if (va and z is not None)]
                if test_constrain_evt_types is not None:
                    one_evt_items = [z for z in one_evt_items if z.type in test_constrain_evt_types]
                # =====
                # todo(note): special rule (actually a simple rule based extender)
                if iconf.expand_evt_compound:
                    for one_evt in one_evt_items:
                        one_hard_span = one_evt.mention.hard_span
                        sid, hwid, _ = one_hard_span.position(True)
                        assert one_hard_span.length == 1  # currently no way to predict more
                        if hwid + 1 < one_sent_inst.length:
                            if one_sent_inst.uposes.vals[hwid] == "VERB" and one_sent_inst.ud_heads.vals[hwid + 1] == hwid \
                                    and one_sent_inst.ud_labels.vals[hwid + 1] == "compound":
                                one_hard_span.length += 1
                # =====
                one_sent_inst.pred_events = one_evt_items
    return {"doc": ndoc, "sent": nsent}