def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
    mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
    if BK.get_shape(mask_idxes, -1) == 0:  # no loss
        zzz = BK.zeros([])
        return [[zzz, zzz, zzz]]
    else:
        target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
        target_hids = self.hid_layer(target_reprs)
        target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
        pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
        target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
        target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones in range
        # get loss
        pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
        pred_loss_sum = (pred_losses * mask_valids).sum()
        pred_loss_count = mask_valids.sum()
        # argmax
        _, argmax_idxes = target_scores.max(-1)
        pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
        pred_corr_count = pred_corrs.sum()
        return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
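# =====
# Illustrative sketch (not part of the original module): a plausible pure-PyTorch
# reading of the BK.mask2idx helper used throughout this file, assuming it turns a
# [bsize, slen] 0/1 mask into left-packed positions plus a validity mask; the real
# backend may differ (e.g., in padding value or device handling).
import torch

def _sketch_mask2idx(mask_t: torch.Tensor, padding_idx: int = 0):
    bsize, slen = mask_t.shape
    counts = mask_t.sum(-1).long()                         # [bsize], valid count per row
    max_count = int(counts.max().item()) if bsize > 0 else 0
    idxes = torch.full((bsize, max_count), padding_idx, dtype=torch.long)
    valids = torch.zeros(bsize, max_count)
    for b in range(bsize):
        pos = mask_t[b].nonzero(as_tuple=True)[0]          # positions where mask==1
        idxes[b, :len(pos)] = pos
        valids[b, :len(pos)] = 1.
    return idxes, valids                                   # [bsize, max_count] x2

# e.g. mask [[1,0,1],[0,0,0]] -> idxes [[0,2],[0,0]], valids [[1,1],[0,0]];
# an all-zero mask yields shape [bsize, 0], matching the `sel_shape[-1] == 0` checks below.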
def lookup(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    bsize = len(insts)
    # first get gold/input info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
    # step 1: no selection, simply forward using gold_masks
    sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1==0
    else:
        # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)  # todo(note): here no softlookup?
        ret_items = sel_items
        # second type
        if self.use_secondary_type:
            sel2_lab_idxes = sel_gold_idxes2
            sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)  # todo(note): here no softlookup?
            sel2_valid_mask = (sel2_lab_idxes > 0).float()
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
    # step 3: exclude nil, assuming no deliberate nil in gold/inputs
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
    # step 4: return
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # mask out invalid items with None (note: np.bool is removed in recent NumPy, use builtin bool)
    ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None
    return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
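# =====
# Standalone analogue (illustrative only) of the advanced-indexing pattern above:
# `arr[np.arange(bsize)[:, np.newaxis], idxes]` picks, for each batch row, the
# items at the selected head positions. All values here are made up.
import numpy as np

_items = np.array([["a", "b", "c"], ["d", "e", "f"]], dtype=object)  # [bsize=2, slen=3]
_idxes = np.array([[2, 0], [1, 1]])                                  # [bsize, mc=2]
_picked = _items[np.arange(2)[:, np.newaxis], _idxes]
# -> [["c", "a"], ["e", "e"]], shape [bsize, mc]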
def _pmask2idxes(self, pred_mask):
    orig_shape = BK.get_shape(pred_mask)
    dim_type = orig_shape[-1]
    flattened_mask = pred_mask.view(orig_shape[:-2] + [-1])  # [*, slen*L]
    f_idxes, sel_valid_mask = BK.mask2idx(flattened_mask)  # [*, max-count]
    # then back to the two dimensions
    sel_idxes, sel_lab_idxes = f_idxes // dim_type, f_idxes % dim_type
    # the embeddings
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        sel_lab_embeds = BK.zeros(sel_shape + [self.conf.lab_conf.n_dim])
    else:
        assert not self.hl.conf.use_lookup_soft, "Cannot do soft-lookup in this mode"
        sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
    return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
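# =====
# Tiny standalone check (illustrative only) of the divmod trick above: flat indices
# into a [slen*L] mask decompose back into (position, label) via // and %, with L
# playing the role of dim_type.
import torch

_L = 3                                        # number of label types per position
_flat_idx = torch.tensor([7, 11])             # indices into a flattened [slen*L] mask
_pos, _lab = _flat_idx // _L, _flat_idx % _L  # -> positions (2, 3), labels (1, 2)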
def _exclude_nil(self, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds,
                 sel_logprobs=None, sel_items_arr=None):
    # todo(note): assure that nil is 0
    sel_valid_mask = sel_valid_mask * (sel_lab_idxes != 0).float()  # not inplaced
    # idx on idx
    s2_idxes, s2_valid_mask = BK.mask2idx(sel_valid_mask)
    sel_idxes = sel_idxes.gather(-1, s2_idxes)
    sel_valid_mask = s2_valid_mask
    sel_lab_idxes = sel_lab_idxes.gather(-1, s2_idxes)
    sel_lab_embeds = BK.gather_first_dims(sel_lab_embeds, s2_idxes, -2)
    sel_logprobs = None if sel_logprobs is None else sel_logprobs.gather(-1, s2_idxes)
    sel_items_arr = None if sel_items_arr is None \
        else sel_items_arr[np.arange(len(sel_items_arr))[:, np.newaxis], BK.get_value(s2_idxes)]
    return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs, sel_items_arr
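# =====
# Illustrative sketch (not part of the original module): one plausible reading of
# the BK.gather_first_dims helper, assuming it gathers along `dim` with per-batch
# index rows while keeping all trailing dims of `t`; e.g. t=[bsize, slen, D],
# idxes=[bsize, mc], dim=1 (or -2) -> [bsize, mc, D]. The real backend may differ.
import torch

def _sketch_gather_first_dims(t: torch.Tensor, idxes: torch.Tensor, dim: int):
    dim = dim % t.dim()                     # normalize negative dims
    tail = list(t.shape[dim + 1:])          # dims kept unchanged after the gather
    idx = idxes.reshape(list(idxes.shape) + [1] * len(tail)).expand(list(idxes.shape) + tail)
    return t.gather(dim, idx)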
def predict(self, insts: List, input_lexi, input_expr, input_mask):
    conf = self.conf
    # step 1: select mention candidates
    if conf.use_selector:
        sel_mask = self.sel.predict(input_expr, input_mask)
    else:
        sel_mask = input_mask
    sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
    # step 2: encoding and labeling
    sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)
    sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(sel_hid_exprs, None)  # [*, mc], [*, mc, D]
    # =====
    if self.use_secondary_type:
        sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
        sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
        sel2_lab_logprobs, sel2_lab_idxes, sel2_lab_embeds = self.hl.predict(sel2_input, None)
        if conf.sectype_t2ift1:
            sel2_lab_idxes *= (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
        # first concat here, then exclude nil in one pass: [*, mc*2, ~]
        if sel2_lab_idxes.sum().item() > 0:  # if there are any predictions
            sel_lab_logprobs = BK.concat([sel_lab_logprobs, sel2_lab_logprobs], -1)
            sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
            sel_valid_mask = BK.concat([sel_valid_mask, sel_valid_mask], -1)
            sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
            sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
    # =====
    # step 3: exclude nil and return
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_lab_logprobs, _ = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds,
                              sel_logprobs=sel_lab_logprobs)
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    return sel_lab_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
    conf = self.conf
    bsize = len(insts)
    # first get gold info, also multiple valid-masks
    gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
    input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
    # step 1: selector
    if conf.use_selector:
        sel_loss, sel_mask = self.sel.loss(input_expr, input_mask, gold_masks, margin=margin)
    else:
        sel_loss, sel_mask = None, self._select_cands_training(input_mask, gold_masks, conf.train_min_rate)
    sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
    sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
    sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
    # todo(+N): only get items by head position!
    _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
    sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
    sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
    # step 2: encoding and labeling
    # if we select nothing
    # ----- debug
    # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
    # -----
    sel_shape = BK.get_shape(sel_idxes)
    if sel_shape[-1] == 0:
        lab_loss = [[BK.zeros([]), BK.zeros([])]]
        sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
        sel_lab_idxes = sel_gold_idxes
        sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
        ret_items = sel_items  # dim-1==0
    else:
        sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
        lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(sel_hid_exprs, sel_valid_mask, sel_gold_idxes, margin=margin)
        if conf.train_gold_corr:
            sel_lab_idxes = sel_gold_idxes
            if not self.hl.conf.use_lookup_soft:
                sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
        ret_items = sel_items
        # =====
        if self.use_secondary_type:
            sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
            if conf.sectype_noback_enc:
                sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
            else:
                sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
            # =====
            # special for the sectype mask (sample it within the gold ones)
            sel2_valid_mask = self._select_cands_training((sel_gold_idxes > 0).float(),
                                                          (sel_gold_idxes2 > 0).float(),
                                                          conf.train_min_rate_s2)
            # =====
            sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(sel2_input, sel2_valid_mask,
                                                                          sel_gold_idxes2, margin=margin)
            if conf.train_gold_corr:
                sel2_lab_idxes = sel_gold_idxes2
                if not self.hl.conf.use_lookup_soft:
                    sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
            if conf.sectype_t2ift1:
                sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
            # combine the two
            if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
        else:
            sel2_lab_loss = None
        # =====
    # step 3: exclude nil and return
    if conf.exclude_nil:  # [*, mc', ...]
        sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
            self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
    # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
    # step 4: finally prepare loss and items
    for one_loss in lab_loss:
        one_loss[0] *= conf.lambda_ne
    ret_losses = lab_loss
    if sel2_lab_loss is not None:
        for one_loss in sel2_lab_loss:
            one_loss[0] *= conf.lambda_ne2
        ret_losses = ret_losses + sel2_lab_loss
    if sel_loss is not None:
        for one_loss in sel_loss:
            one_loss[0] *= conf.lambda_ns
        ret_losses = ret_losses + sel_loss
    # ----- debug
    # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
    # -----
    # mask out invalid items with None (builtin bool: np.bool is removed in recent NumPy)
    ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None
    return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
def forward_batch(self, batched_ids: List, batched_starts: List, batched_typeids: List,
                  training: bool, other_inputs: List[List] = None):
    conf = self.bconf
    tokenizer = self.tokenizer
    PAD_IDX = tokenizer.pad_token_id
    MASK_IDX = tokenizer.mask_token_id
    CLS_IDX = tokenizer.cls_token_id
    SEP_IDX = tokenizer.sep_token_id
    if other_inputs is None:
        other_inputs = []
    # =====
    # batch: here add CLS and SEP
    bsize = len(batched_ids)
    max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
    input_shape = (bsize, max_len)
    # first collect on CPU
    input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
    input_ids_arr[:, 0] = CLS_IDX
    input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
    input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
    input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
    other_input_arrs = [np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs]
    if conf.bert2_retinc_cls:  # act as the ROOT word
        input_is_start_arr[:, 0] = 1
    training_mask_rate = conf.bert2_training_mask_rate if training else 0.
    self_sample_stream = self.random_sample_stream
    for bidx in range(bsize):
        cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
        cur_end = len(cur_ids) + 2  # plus CLS and SEP
        if training_mask_rate > 0.:
            # input dropout: randomly replace inputs with [MASK]
            input_ids_arr[bidx, 1:cur_end] = [
                (MASK_IDX if next(self_sample_stream) < training_mask_rate else z) for z in cur_ids
            ] + [SEP_IDX]
        else:
            input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
        input_is_start_arr[bidx, 1:cur_end-1] = cur_starts
        input_mask_arr[bidx, :cur_end] = 1.
        if batched_typeids is not None and batched_typeids[bidx] is not None:
            input_typeids[bidx, 1:cur_end-1] = batched_typeids[bidx]
        for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
            one_other_input_arr[bidx, 1:cur_end-1] = one_other_input_list[bidx]
    # arr to tensor
    input_ids_t = BK.input_idx(input_ids_arr)
    input_mask_t = BK.input_real(input_mask_arr)
    input_is_start_t = BK.input_idx(input_is_start_arr)
    input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
    other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
    # =====
    # forward (maybe need multiple times to fit maxlen constraint)
    MAX_LEN = 510  # save two for [CLS] and [SEP]
    BACK_LEN = 100  # for splitting cases, still retain some previous sub-tokens for context
    if max_len <= MAX_LEN:
        # directly once
        final_outputs = self.forward_features(input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
    else:
        all_outputs = []
        cur_sub_idx = 0
        slice_size = [bsize, 1]
        slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
        slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
        while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
            cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
            cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
            cur_input_ids_t = BK.concat([slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
            # here we simply extend extra original masks
            cur_input_mask_t = input_mask_t[:, cur_slice_start-1:cur_slice_end+1]
            cur_input_typeid_t = None if input_typeid_t is None else input_typeid_t[:, cur_slice_start-1:cur_slice_end+1]
            cur_other_input_ts = [z[:, cur_slice_start-1:cur_slice_end+1] for z in other_input_ts]
            cur_outputs = self.forward_features(cur_input_ids_t, cur_input_mask_t, cur_input_typeid_t, cur_other_input_ts)
            # only include CLS in the first run, no SEP included
            if cur_sub_idx == 0:
                # include CLS, exclude SEP
                all_outputs.append(cur_outputs[:, :-1])
            else:
                # include only new ones, discard BACK ones; exclude CLS, SEP
                all_outputs.append(cur_outputs[:, cur_sub_idx-cur_slice_start+1:-1])
                zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                      f"for all-len={max_len}")
            cur_sub_idx = cur_slice_end
        final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
        start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
    start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
    return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
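# =====
# Illustrative helper (not part of the original module): mirrors the slicing
# arithmetic of the over-length branch above, returning the (start, new-start, end)
# triples each sub-forward covers, so the BACK_LEN overlap behavior can be checked
# without running a real encoder.
def _sketch_window_ranges(max_len: int, MAX_LEN: int = 510, BACK_LEN: int = 100):
    ranges = []
    cur_sub_idx = 0
    while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
        cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
        cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
        ranges.append((cur_slice_start, cur_sub_idx, cur_slice_end))
        cur_sub_idx = cur_slice_end
    return ranges

# e.g. _sketch_window_ranges(1000) -> [(1, 0, 511), (411, 511, 921), (821, 921, 999)]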
def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # prepare idxes for the masked ones
    if self.add_root_token:
        # offset for the special root added in embedder
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
        repr_mask_idxes = mask_idxes + 1
        mask_idxes.clamp_(min=0)
    else:
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
        repr_mask_idxes = mask_idxes
    # get the losses
    if BK.get_shape(mask_idxes, -1) == 0:  # no loss
        return self._compile_component_loss("mlm", [])
    else:
        if not isinstance(repr_ts, (List, Tuple)):
            repr_ts = [repr_ts]
        target_word_scores, target_pos_scores = [], []
        target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
        for layer_idx in conf.loss_layers:
            # calculate scores
            target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
            if self.hid_layer and active_hid:
                # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                target_hids = self.hid_layer(target_reprs)
            else:
                target_hids = target_reprs
            if _tie_input_embeddings:
                pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
            else:
                target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
        # gather the losses
        all_losses = []
        for pred_name, target_scores, loss_lambda, range_min, range_max in \
                zip(["word", "pos"], [target_word_scores, target_pos_scores],
                    [conf.lambda_word, conf.lambda_pos], [conf.min_pred_rank, 0],
                    [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
            if loss_lambda > 0.:
                seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() \
                                     * (target_idx_t <= range_max).float()
                target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones in range
                # calculate for each layer
                all_layer_losses, all_layer_scores = [], []
                for one_layer_idx, one_target_scores in enumerate(target_scores):
                    # get loss: [bsize, ?]
                    one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                    all_layer_losses.append(one_pred_losses)
                    # get scores
                    one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                    all_layer_scores.append(one_pred_scores)
                # combine all layers
                pred_losses = self.loss_comb_f(all_layer_losses)
                pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                pred_loss_count = ranged_mask_valids.sum()
                # argmax
                _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                pred_corr_count = pred_corrs.sum()
                # compile leaf loss
                r_loss = LossHelper.compile_leaf_info(pred_name, pred_loss_sum, pred_loss_count,
                                                      loss_lambda=loss_lambda, corr=pred_corr_count)
                all_losses.append(r_loss)
        return self._compile_component_loss("mlm", all_losses)
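# =====
# Minimal sketch (illustrative only) of the tied-input-embedding scoring above:
# instead of a separate prediction layer, the first pred_word_size rows of the
# input embedding table E [Vocab, Dim] are reused as the output projection.
# All shapes and values here are made up for the example.
import torch

_E = torch.randn(10000, 128)            # hypothetical input embedding table [Vocab, Dim]
_target_hids = torch.randn(4, 7, 128)   # [bsize, ?, Dim]
_pred_W = _E[:8000]                     # [PSize, Dim], PSize == pred_word_size
_scores = torch.matmul(_target_hids, _pred_W.t())  # [bsize, ?, PSize]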