def _fb_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
             evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt, margin):
    # get the gold idxes
    arg_linker = self.arg_linker
    bsize, len_ef = ef_items.shape
    bsize2, len_evt = evt_items.shape
    assert bsize == bsize2
    gold_idxes = np.zeros([bsize, len_ef, len_evt], dtype=np.int64)
    for one_gold_idxes, one_ef_items, one_evt_items in zip(gold_idxes, ef_items, evt_items):
        # todo(note): check each pair
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            role_map = {id(z.evt): z.role_idx for z in one_ef.links}  # todo(note): since we get the original linked ones
            for evt_idx, one_evt in enumerate(one_evt_items):
                pairwise_role_hlidx = role_map.get(id(one_evt))
                if pairwise_role_hlidx is not None:
                    pairwise_role_idx = arg_linker.hlidx2idx(pairwise_role_hlidx)
                    assert pairwise_role_idx > 0
                    one_gold_idxes[ef_idx, evt_idx] = pairwise_role_idx
    # get loss
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    if np.prod(gold_idxes.shape) == 0:
        # no instances!
        return [[BK.zeros([]), BK.zeros([])]]
    else:
        gold_idxes_t = BK.input_idx(gold_idxes)
        return arg_linker.loss(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                               ef_valid_mask, evt_valid_mask, gold_idxes_t, margin)
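# NOTE (illustrative sketch, not part of the model): the gold_idxes array built above is a
# [bsize, len_ef, len_evt] grid whose cell (i, j) holds the role index linking entity-filler i to
# event j (0 = no link). A minimal numpy analogue with hypothetical stand-in items:
def _demo_gold_role_grid():
    import numpy as np
    class _Link:
        def __init__(self, evt, role_idx):
            self.evt, self.role_idx = evt, role_idx
    class _Item:
        def __init__(self):
            self.links = []
    evts = [_Item(), _Item()]
    ef = _Item()
    ef.links.append(_Link(evts[1], role_idx=3))  # this filler plays role 3 in the second event
    grid = np.zeros([1, 1, 2], dtype=np.int64)   # [bsize=1, len_ef=1, len_evt=2]
    role_map = {id(link.evt): link.role_idx for link in ef.links}
    for j, evt in enumerate(evts):
        role = role_map.get(id(evt))
        if role is not None:
            grid[0, 0, j] = role
    return grid  # -> [[[0, 3]]]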
def _inference_args(self, ef_items, ef_widxes, ef_valid_mask, ef_lab_idxes, enc_repr_ef,
                    evt_items, evt_widxes, evt_valid_mask, evt_lab_idxes, enc_repr_evt):
    arg_linker = self.arg_linker
    repr_ef = BK.gather_first_dims(enc_repr_ef, ef_widxes, -2)  # [*, len-ef, D]
    repr_evt = BK.gather_first_dims(enc_repr_evt, evt_widxes, -2)  # [*, len-evt, D]
    role_logprobs, role_predictions = arg_linker.predict(repr_ef, repr_evt, ef_lab_idxes, evt_lab_idxes,
                                                         ef_valid_mask, evt_valid_mask)
    # add them in place
    roles_arr = BK.get_value(role_predictions)  # [*, len-ef, len-evt]
    logprobs_arr = BK.get_value(role_logprobs)
    for bidx, one_roles_arr in enumerate(roles_arr):
        one_ef_items, one_evt_items = ef_items[bidx], evt_items[bidx]
        # =====
        # todo(note): delete the original links!
        for z in one_ef_items:
            if z is not None:
                z.links.clear()
        for z in one_evt_items:
            if z is not None:
                z.links.clear()
        # =====
        one_logprobs = logprobs_arr[bidx]
        for ef_idx, one_ef in enumerate(one_ef_items):
            if one_ef is None:
                continue
            for evt_idx, one_evt in enumerate(one_evt_items):
                if one_evt is None:
                    continue
                one_role_idx = int(one_roles_arr[ef_idx, evt_idx])
                if one_role_idx > 0:
                    # link
                    this_hlidx = arg_linker.idx2hlidx(one_role_idx)
                    one_evt.add_arg(one_ef, role=str(this_hlidx), role_idx=this_hlidx,
                                    score=float(one_logprobs[ef_idx, evt_idx]))
def _enc(self, input_lexi, input_expr, input_mask, sel_idxes):
    if self.dmxnn:
        bsize, slen = BK.get_shape(input_mask)
        if sel_idxes is None:
            sel_idxes = BK.arange_idx(slen).unsqueeze(0)  # select all, [1, slen]
        ncand = BK.get_shape(sel_idxes, -1)
        # enc_expr augmented with PE
        rel_dist = BK.arange_idx(slen).unsqueeze(0).unsqueeze(0) - sel_idxes.unsqueeze(-1)  # [*, ?, slen]
        pe_embeds = self.posi_embed(rel_dist)  # [*, ?, slen, Dpe]
        aug_enc_expr = BK.concat([pe_embeds.expand(bsize, -1, -1, -1),
                                  input_expr.unsqueeze(1).expand(-1, ncand, -1, -1)], -1)  # [*, ?, slen, D+Dpe]
        # [*, ?, slen, Denc]
        hidden_expr = self.e_encoder(aug_enc_expr.view(bsize * ncand, slen, -1),
                                     input_mask.unsqueeze(1).expand(-1, ncand, -1).contiguous().view(bsize * ncand, slen))
        hidden_expr = hidden_expr.view(bsize, ncand, slen, -1)
        # dynamic max-pooling (dist<0, dist=0, dist>0)
        NEG = Constants.REAL_PRAC_MIN
        mp_hiddens = []
        mp_masks = [rel_dist < 0, rel_dist == 0, rel_dist > 0]
        for mp_mask in mp_masks:
            float_mask = mp_mask.float() * input_mask.unsqueeze(-2)  # [*, ?, slen]
            valid_mask = (float_mask.sum(-1) > 0.).float().unsqueeze(-1)  # [*, ?, 1]
            mask_neg_val = (1. - float_mask).unsqueeze(-1) * NEG  # [*, ?, slen, 1]
            # todo(+2): or do we simply multiply mask?
            mp_hid0 = (hidden_expr + mask_neg_val).max(-2)[0]
            mp_hid = mp_hid0 * valid_mask  # [*, ?, Denc]
            mp_hiddens.append(self.special_drop(mp_hid))
            # mp_hiddens.append(mp_hid)
        final_hiddens = mp_hiddens
    else:
        hidden_expr = self.e_encoder(input_expr, input_mask)  # [*, slen, D']
        if sel_idxes is None:
            hidden_expr1 = hidden_expr
        else:
            hidden_expr1 = BK.gather_first_dims(hidden_expr, sel_idxes, -2)  # [*, ?, D']
        final_hiddens = [self.special_drop(hidden_expr1)]
    if self.lab_f_use_lexi:
        final_hiddens.append(BK.gather_first_dims(input_lexi, sel_idxes, -2))  # [*, ?, DLex]
    ret_expr = self.lab_f(final_hiddens)  # [*, ?, DLab]
    return ret_expr
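# NOTE (illustrative sketch): the dynamic max-pooling above splits each sentence into three
# segments relative to the candidate token (left of it, at it, right of it) and max-pools each
# segment, pushing masked-out positions to a very negative value first. A minimal numpy version
# of that masked max (names here are hypothetical, not the model's API):
def _demo_dynamic_maxpool():
    import numpy as np
    NEG = -1e30
    hidden = np.arange(12, dtype=np.float32).reshape(4, 3)     # [slen=4, D=3]
    rel_dist = np.array([-2, -1, 0, 1])                        # distance to the candidate at position 2
    pooled = []
    for seg_mask in (rel_dist < 0, rel_dist == 0, rel_dist > 0):
        fmask = seg_mask.astype(np.float32)
        masked = hidden + (1. - fmask)[:, None] * NEG          # push non-segment rows to -inf
        seg_max = masked.max(axis=0) * float(fmask.sum() > 0)  # zero out empty segments
        pooled.append(seg_max)
    return np.concatenate(pooled)  # [3*D]: left max, center token, right max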
def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
    mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
    if BK.get_shape(mask_idxes, -1) == 0:
        # no loss
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    else:
        target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
        target_hids = self.hid_layer(target_reprs)
        target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
        pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
        target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
        target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones are in range
        # get loss
        pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
        pred_loss_sum = (pred_losses * mask_valids).sum()
        pred_loss_count = mask_valids.sum()
        # argmax
        _, argmax_idxes = target_scores.max(-1)
        pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
        pred_corr_count = pred_corrs.sum()
        return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
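# NOTE (illustrative sketch): BK.mask2idx appears to turn a 0/1 mask of shape [bsize, slen] into
# padded index lists plus a validity mask, which is what lets the loss above gather only the
# masked positions. A plausible numpy analogue (an assumption about the helper, not its actual code):
def _demo_mask2idx(mask_arr):
    import numpy as np
    mask_arr = np.asarray(mask_arr, dtype=np.float32)         # [bsize, slen]
    counts = mask_arr.sum(-1).astype(np.int64)                # masked tokens per row
    width = int(counts.max()) if mask_arr.size > 0 else 0
    idxes = np.zeros((mask_arr.shape[0], width), dtype=np.int64)
    valids = np.zeros((mask_arr.shape[0], width), dtype=np.float32)
    for b, row in enumerate(mask_arr):
        pos = np.nonzero(row)[0]
        idxes[b, :len(pos)] = pos
        valids[b, :len(pos)] = 1.
    return idxes, valids
# e.g. _demo_mask2idx([[0, 1, 1], [1, 0, 0]]) -> ([[1, 2], [0, 0]], [[1., 1.], [1., 0.]])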
def _step(self, input_expr, input_mask, hard_coverage, prev_state, force_widx, force_lidx, free_beam_size):
    conf = self.conf
    free_mode = (force_widx is None)
    prev_state_h = prev_state[0]
    # =====
    # collect att scores
    key_up = self.affine_k([input_expr, hard_coverage.unsqueeze(-1)])  # [*, slen, h]
    query_up = self.affine_q([self.repos.unsqueeze(0), prev_state_h.unsqueeze(-2)])  # [*, R, h]
    orig_scores = BK.matmul(key_up, query_up.transpose(-2, -1))  # [*, slen, R]
    orig_scores += (1. - input_mask).unsqueeze(-1) * Constants.REAL_PRAC_MIN  # [*, slen, R]
    # first take the maximum across the R dim (this step is hard max)
    maxr_scores, maxr_idxes = orig_scores.max(-1)  # [*, slen]
    if conf.zero_eos_score:
        # use a mask so the operation stays differentiable
        tmp_mask = BK.constants(BK.get_shape(maxr_scores), 1.)
        tmp_mask.index_fill_(-1, BK.input_idx(0), 0.)
        maxr_scores *= tmp_mask
    # then select over the slen dim (this step is prob based)
    maxr_logprobs = BK.log_softmax(maxr_scores)  # [*, slen]
    if free_mode:
        cur_beam_size = min(free_beam_size, BK.get_shape(maxr_logprobs, -1))
        sel_tok_logprobs, sel_tok_idxes = maxr_logprobs.topk(cur_beam_size, dim=-1, sorted=False)  # [*, beam]
    else:
        sel_tok_idxes = force_widx.unsqueeze(-1)  # [*, 1]
        sel_tok_logprobs = maxr_logprobs.gather(-1, sel_tok_idxes)  # [*, 1]
    # then collect the info and perform labeling
    lf_input_expr = BK.gather_first_dims(input_expr, sel_tok_idxes, -2)  # [*, ?, ~]
    lf_coverage = hard_coverage.gather(-1, sel_tok_idxes).unsqueeze(-1)  # [*, ?, 1]
    lf_repos = self.repos[maxr_idxes.gather(-1, sel_tok_idxes)]  # [*, ?, ~]
    # todo(+3): using soft version?
    lf_prev_state = prev_state_h.unsqueeze(-2)  # [*, 1, ~]
    lab_hid_expr = self.lab_f([lf_input_expr, lf_coverage, lf_repos, lf_prev_state])  # [*, ?, ~]
    # finally predict the labels
    # todo(+N): here we select only max at labeling part, only beam at previous one
    if free_mode:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, None)  # [*, ?]
    else:
        sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(lab_hid_expr, force_lidx.unsqueeze(-1))
    # no lab-logprob (*=0) for eos (sel_tok==0)
    sel_lab_logprobs *= (sel_tok_idxes > 0).float()
    # compute next-state [*, ?, ~]
    # todo(note): here we flatten the first two dims
    tmp_rnn_dims = BK.get_shape(sel_tok_idxes) + [-1]
    tmp_rnn_input = BK.concat([lab_hid_expr, sel_lab_embeds], -1)
    tmp_rnn_input = tmp_rnn_input.view(-1, BK.get_shape(tmp_rnn_input, -1))
    tmp_rnn_hidden = [z.unsqueeze(-2).expand(tmp_rnn_dims).contiguous().view(-1, BK.get_shape(z, -1))
                      for z in prev_state]  # [*, ?, ?, D]
    next_state = self.rnn_unit(tmp_rnn_input, tmp_rnn_hidden, None)
    next_state = [z.view(tmp_rnn_dims) for z in next_state]
    return sel_tok_idxes, sel_tok_logprobs, sel_lab_idxes, sel_lab_logprobs, sel_lab_embeds, next_state
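# NOTE (illustrative sketch): the token-selection part of _step first hard-maxes the scores over
# the R relation-position vectors, masks padded tokens with a very negative constant, and only
# then log-normalizes over the sentence before topk/gather. A small numpy analogue of that order
# of operations (toy numbers, not the model's tensors):
def _demo_two_stage_selection():
    import numpy as np
    scores = np.array([[0.5, 2.0], [1.0, 0.1], [3.0, -1.0]])  # [slen=3, R=2]
    mask = np.array([1., 1., 0.])                              # last token is padding
    maxr = scores.max(-1) + (1. - mask) * -1e30                # hard max over R, then mask pads
    shifted = maxr - maxr.max()
    logprobs = shifted - np.log(np.exp(shifted).sum())         # log-softmax over the sentence
    return int(np.argmax(logprobs))  # -> 0: the padded position cannot win despite its high raw score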
def _exclude_nil(self, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds,
                 sel_logprobs=None, sel_items_arr=None):
    # todo(note): assure that nil is 0
    sel_valid_mask = sel_valid_mask * (sel_lab_idxes != 0).float()  # not in-place
    # idx on idx
    s2_idxes, s2_valid_mask = BK.mask2idx(sel_valid_mask)
    sel_idxes = sel_idxes.gather(-1, s2_idxes)
    sel_valid_mask = s2_valid_mask
    sel_lab_idxes = sel_lab_idxes.gather(-1, s2_idxes)
    sel_lab_embeds = BK.gather_first_dims(sel_lab_embeds, s2_idxes, -2)
    sel_logprobs = None if sel_logprobs is None else sel_logprobs.gather(-1, s2_idxes)
    sel_items_arr = None if sel_items_arr is None \
        else sel_items_arr[np.arange(len(sel_items_arr))[:, np.newaxis], BK.get_value(s2_idxes)]
    return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs, sel_items_arr
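# NOTE (illustrative sketch): _exclude_nil compacts predictions by dropping the NIL label (index 0)
# and re-gathering every parallel tensor at the surviving positions. A single-row numpy analogue of
# that effect (the real code keeps everything batched and padded via mask2idx/gather):
def _demo_exclude_nil():
    import numpy as np
    sel_idxes = np.array([4, 7, 9])            # selected token positions
    sel_lab_idxes = np.array([2, 0, 5])        # predicted labels, 0 = NIL
    keep = np.nonzero(sel_lab_idxes != 0)[0]   # analogue of mask2idx on the valid mask
    return sel_idxes[keep], sel_lab_idxes[keep]  # -> ([4, 9], [2, 5])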
def forward_batch(self, batched_ids: List, batched_starts: List, batched_typeids: List,
                  training: bool, other_inputs: List[List] = None):
    conf = self.bconf
    tokenizer = self.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    MASK_IDX = tokenizer.mask_token_id
    CLS_IDX = tokenizer.cls_token_id
    SEP_IDX = tokenizer.sep_token_id
    if other_inputs is None:
        other_inputs = []
    # =====
    # batch: here add CLS and SEP
    bsize = len(batched_ids)
    max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
    input_shape = (bsize, max_len)
    # first collect on CPU
    input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
    input_ids_arr[:, 0] = CLS_IDX
    input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
    input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
    input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
    other_input_arrs = [np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs]
    if conf.bert2_retinc_cls:  # act as the ROOT word
        input_is_start_arr[:, 0] = 1
    training_mask_rate = conf.bert2_training_mask_rate if training else 0.
    self_sample_stream = self.random_sample_stream
    for bidx in range(bsize):
        cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
        cur_end = len(cur_ids) + 2  # plus CLS and SEP
        if training_mask_rate > 0.:
            # input dropout
            input_ids_arr[bidx, 1:cur_end] = [(MASK_IDX if next(self_sample_stream) < training_mask_rate else z)
                                              for z in cur_ids] + [SEP_IDX]
        else:
            input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
        input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
        input_mask_arr[bidx, :cur_end] = 1.
        if batched_typeids is not None and batched_typeids[bidx] is not None:
            input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
        for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
            one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
    # arr to tensor
    input_ids_t = BK.input_idx(input_ids_arr)
    input_mask_t = BK.input_real(input_mask_arr)
    input_is_start_t = BK.input_idx(input_is_start_arr)
    input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
    other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
    # =====
    # forward (may need multiple runs to fit the max-length constraint)
    MAX_LEN = 510  # save two for [CLS] and [SEP]
    BACK_LEN = 100  # for splitting cases, still keep some of the previous sub-tokens as context
    if max_len <= MAX_LEN:
        # directly in one run
        final_outputs = self.forward_features(input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
    else:
        all_outputs = []
        cur_sub_idx = 0
        slice_size = [bsize, 1]
        slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
        slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
        while cur_sub_idx < max_len - 1:  # minus 1 to ignore the ending SEP
            cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
            cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
            cur_input_ids_t = BK.concat([slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
            # here we simply extend the original masks by one position at each side
            cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_input_typeid_t = None if input_typeid_t is None else input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
            cur_other_input_ts = [z[:, cur_slice_start - 1:cur_slice_end + 1] for z in other_input_ts]
            cur_outputs = self.forward_features(cur_input_ids_t, cur_input_mask_t, cur_input_typeid_t, cur_other_input_ts)
            # only include CLS in the first run, no SEP included
            if cur_sub_idx == 0:
                # include CLS, exclude SEP
                all_outputs.append(cur_outputs[:, :-1])
            else:
                # include only the new ones, discard the BACK ones; exclude CLS and SEP
                all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
            zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end}) for all-len={max_len}")
            cur_sub_idx = cur_slice_end
        final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
    start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
    return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
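# NOTE (illustrative sketch): for sequences longer than MAX_LEN, forward_batch above re-runs the
# encoder on overlapping windows and keeps only the new positions from each run. The slice
# arithmetic can be previewed standalone (same constants as above; purely illustrative):
def _demo_long_seq_slices(max_len, MAX_LEN=510, BACK_LEN=100):
    slices = []
    cur_sub_idx = 0
    while cur_sub_idx < max_len - 1:  # ignore the ending SEP
        cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
        cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
        # kept outputs of this run cover positions [cur_sub_idx, cur_slice_end) (position 0 is CLS in the first run)
        slices.append((cur_slice_start, cur_sub_idx, cur_slice_end))
        cur_sub_idx = cur_slice_end
    return slices
# e.g. _demo_long_seq_slices(1200) -> [(1, 0, 511), (411, 511, 921), (821, 921, 1199)]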
def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # prepare idxes for the masked ones
    if self.add_root_token:
        # offset for the special root added in the embedder
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
        repr_mask_idxes = mask_idxes + 1
        mask_idxes.clamp_(min=0)
    else:
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
        repr_mask_idxes = mask_idxes
    # get the losses
    if BK.get_shape(mask_idxes, -1) == 0:
        # no loss
        return self._compile_component_loss("mlm", [])
    else:
        if not isinstance(repr_ts, (List, Tuple)):
            repr_ts = [repr_ts]
        target_word_scores, target_pos_scores = [], []
        target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
        for layer_idx in conf.loss_layers:
            # calculate scores
            target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
            if self.hid_layer and active_hid:
                # todo(+N): sometimes, we only want the last softmax, need to ensure dim at outside!
                target_hids = self.hid_layer(target_reprs)
            else:
                target_hids = target_reprs
            if _tie_input_embeddings:
                pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
            else:
                target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
        # gather the losses
        all_losses = []
        for pred_name, target_scores, loss_lambda, range_min, range_max in \
                zip(["word", "pos"], [target_word_scores, target_pos_scores],
                    [conf.lambda_word, conf.lambda_pos], [conf.min_pred_rank, 0],
                    [min(conf.max_pred_rank, self.pred_word_size - 1), self.pred_pos_size - 1]):
            if loss_lambda > 0.:
                seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() * (target_idx_t <= range_max).float()
                target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones are in range
                # calculate for each layer
                all_layer_losses, all_layer_scores = [], []
                for one_layer_idx, one_target_scores in enumerate(target_scores):
                    # get loss: [bsize, ?]
                    one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                    all_layer_losses.append(one_pred_losses)
                    # get scores
                    one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                    all_layer_scores.append(one_pred_scores)
                # combine all layers
                pred_losses = self.loss_comb_f(all_layer_losses)
                pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                pred_loss_count = ranged_mask_valids.sum()
                # argmax
                _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                pred_corr_count = pred_corrs.sum()
                # compile leaf loss
                r_loss = LossHelper.compile_leaf_info(pred_name, pred_loss_sum, pred_loss_count,
                                                      loss_lambda=loss_lambda, corr=pred_corr_count)
                all_losses.append(r_loss)
        return self._compile_component_loss("mlm", all_losses)
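# NOTE (illustrative sketch): with tie_input_embeddings, the word scores above are just the hidden
# states multiplied by (a truncated slice of) the input embedding matrix transposed. In plain numpy
# terms, with hypothetical sizes:
def _demo_tied_embedding_scores():
    import numpy as np
    rng = np.random.default_rng(0)
    E = rng.standard_normal((1000, 16)).astype(np.float32)       # input embedding table [V, Dim]
    pred_word_size = 300                                         # only score the most frequent words
    hids = rng.standard_normal((2, 5, 16)).astype(np.float32)    # [bsize, n_masked, Dim]
    scores = hids @ E[:pred_word_size].T                         # [bsize, n_masked, pred_word_size]
    return scores.shape  # -> (2, 5, 300)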