Example #1
 def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
     mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         zzz = BK.zeros([])
         return [[zzz, zzz, zzz]]
     else:
         target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
         target_hids = self.hid_layer(target_reprs)
         target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
         pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
         target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
         target_idx_t[(mask_valids < 1.)] = 0  # make sure invalid ones are in range
         # get loss
         pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
         pred_loss_sum = (pred_losses * mask_valids).sum()
         pred_loss_count = mask_valids.sum()
         # argmax
         _, argmax_idxes = target_scores.max(-1)
         pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
         pred_corr_count = pred_corrs.sum()
         return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
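Note: every example here pivots on BK.mask2idx, whose definition is not shown. A minimal sketch of its assumed behavior, re-implemented in plain PyTorch (mask2idx_sketch is a hypothetical stand-in, not the library's code): it converts a binary mask [bsize, slen] into per-row hit indices right-padded to the widest row, plus a validity mask over that padding; with no hits anywhere the returned width is 0, which is exactly what the get_shape(...) == 0 guard above checks.

 import torch

 def mask2idx_sketch(mask: torch.Tensor, padding_idx: int = 0):
     # mask: [bsize, slen] of 0./1. -> (idxes [bsize, mc], valids [bsize, mc])
     counts = mask.sum(-1).long()  # number of hits per row
     mc = int(counts.max().item()) if counts.numel() > 0 else 0
     idxes = torch.full((mask.shape[0], mc), padding_idx, dtype=torch.long)
     valids = torch.zeros(mask.shape[0], mc)
     for b in range(mask.shape[0]):
         hits = mask[b].nonzero(as_tuple=True)[0]
         idxes[b, :len(hits)] = hits
         valids[b, :len(hits)] = 1.
     return idxes, valids  # e.g. [[1., 0., 1.]] -> idxes [[0, 2]], valids [[1., 1.]]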
Example #2
 def lookup(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     bsize = len(insts)
     # first get gold/input info, also multiple valid-masks
     gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
     # step 1: no selection, simply forward using gold_masks
     sel_idxes, sel_valid_mask = BK.mask2idx(gold_masks)  # [*, max-count]
     sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
     sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
     # todo(+N): only get items by head position!
     _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
     sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
     sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
     # step 2: encoding and labeling
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
         ret_items = sel_items  # dim-1==0
     else:
         # sel_hid_exprs = self._enc(input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = self.hl.lookup(sel_lab_idxes)  # todo(note): here no softlookup?
         ret_items = sel_items
         # second type
         if self.use_secondary_type:
             sel2_lab_idxes = sel_gold_idxes2
             sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)  # todo(note): here no softlookup?
             sel2_valid_mask = (sel2_lab_idxes > 0).float()
             # combine the two
             if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                 ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                 sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                 sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                 sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                 sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
     # step 3: exclude nil assuming no deliberate nil in gold/inputs
     if conf.exclude_nil:  # [*, mc', ...]
         sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
             self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
     # step 4: return
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     # mask out invalid items with None
     ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None
     return ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
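The final None-masking step works because gold_items_arr (and thus ret_items) is assumed to be a NumPy object array, so boolean indexing can write None in place. A toy illustration (hypothetical values):

 import numpy as np
 items = np.array([["a", "b"], ["c", "d"]], dtype=object)  # [bsize, mc]
 valid = np.array([[1., 0.], [1., 1.]], dtype=np.float32)  # [bsize, mc]
 items[(1. - valid).astype(bool)] = None                   # -> [["a", None], ["c", "d"]]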
Example #3
 def _pmask2idxes(self, pred_mask):
     orig_shape = BK.get_shape(pred_mask)
     dim_type = orig_shape[-1]
     flattened_mask = pred_mask.view(orig_shape[:-2] + [-1])  # [*, slen*L]
     f_idxes, sel_valid_mask = BK.mask2idx(flattened_mask)  # [*, max-count]
     # then back to the two dimensions
     sel_idxes, sel_lab_idxes = f_idxes // dim_type, f_idxes % dim_type
     # the embeddings
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         sel_lab_embeds = BK.zeros(sel_shape + [self.conf.lab_conf.n_dim])
     else:
         assert not self.hl.conf.use_lookup_soft, "Cannot do soft-lookup in this mode"
         sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
     return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
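The floor-division/modulo pair above undoes the flattening: folding [*, slen, L] into [*, slen*L] maps position pos and label lab to f = pos * L + lab, so pos = f // L and lab = f % L. A quick round-trip check with toy shapes (not the real model's sizes):

 import torch
 bsize, slen, L = 1, 4, 3
 pred_mask = torch.zeros(bsize, slen, L)
 pred_mask[0, 1, 2] = 1.  # token 1 fires with label 2
 f = pred_mask.view(bsize, slen * L)[0].nonzero(as_tuple=True)[0]  # tensor([5])
 pos, lab = f // L, f % L  # tensor([1]), tensor([2])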
Example #4
 def _exclude_nil(self,
                  sel_idxes,
                  sel_valid_mask,
                  sel_lab_idxes,
                  sel_lab_embeds,
                  sel_logprobs=None,
                  sel_items_arr=None):
     # todo(note): assure that nil is 0
     sel_valid_mask = sel_valid_mask * (sel_lab_idxes != 0).float()  # not in-place
     # idx on idx
     s2_idxes, s2_valid_mask = BK.mask2idx(sel_valid_mask)
     sel_idxes = sel_idxes.gather(-1, s2_idxes)
     sel_valid_mask = s2_valid_mask
     sel_lab_idxes = sel_lab_idxes.gather(-1, s2_idxes)
     sel_lab_embeds = BK.gather_first_dims(sel_lab_embeds, s2_idxes, -2)
     sel_logprobs = None if sel_logprobs is None else sel_logprobs.gather(-1, s2_idxes)
     sel_items_arr = None if sel_items_arr is None \
         else sel_items_arr[np.arange(len(sel_items_arr))[:, np.newaxis], BK.get_value(s2_idxes)]
     return sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs, sel_items_arr
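The exclusion amounts to a second mask2idx pass over the surviving non-nil entries followed by gathers along the last dimension. A toy trace of the compaction (hypothetical values, plain PyTorch in place of BK):

 import torch
 sel_lab_idxes = torch.tensor([[3, 0, 5]])      # nil label is 0
 sel_valid_mask = (sel_lab_idxes != 0).float()  # [[1., 0., 1.]]
 s2_idxes = sel_valid_mask[0].nonzero(as_tuple=True)[0].unsqueeze(0)  # [[0, 2]]
 kept_labels = sel_lab_idxes.gather(-1, s2_idxes)  # tensor([[3, 5]])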
Example #5
 def predict(self, insts: List, input_lexi, input_expr, input_mask):
     conf = self.conf
     # step 1: select mention candidates
     if conf.use_selector:
         sel_mask = self.sel.predict(input_expr, input_mask)
     else:
         sel_mask = input_mask
     sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
     # step 2: encoding and labeling
     sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)
     sel_lab_logprobs, sel_lab_idxes, sel_lab_embeds = self.hl.predict(sel_hid_exprs, None)  # [*, mc], [*, mc, D]
     # =====
     if self.use_secondary_type:
         sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
         sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
         sel2_lab_logprobs, sel2_lab_idxes, sel2_lab_embeds = self.hl.predict(sel2_input, None)
         if conf.sectype_t2ift1:
             sel2_lab_idxes *= (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
         # first concat here and then exclude nil at one pass # [*, mc*2, ~]
         if sel2_lab_idxes.sum().item() > 0:  # if there are any predictions
             sel_lab_logprobs = BK.concat([sel_lab_logprobs, sel2_lab_logprobs], -1)
             sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
             sel_valid_mask = BK.concat([sel_valid_mask, sel_valid_mask], -1)
             sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
             sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
     # =====
     # step 3: exclude nil and return
     if conf.exclude_nil:  # [*, mc', ...]
         sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_lab_logprobs, _ = \
             self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_logprobs=sel_lab_logprobs)
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     return sel_lab_logprobs, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
Example #6
 def loss(self, insts: List, input_lexi, input_expr, input_mask, margin=0.):
     conf = self.conf
     bsize = len(insts)
     # first get gold info, also multiple valid-masks
     gold_masks, gold_idxes, gold_items_arr, gold_valid, gold_idxes2, gold_items2_arr = self.batch_inputs_h(insts)
     input_mask = input_mask * gold_valid.unsqueeze(-1)  # [*, slen]
     # step 1: selector
     if conf.use_selector:
         sel_loss, sel_mask = self.sel.loss(input_expr,
                                            input_mask,
                                            gold_masks,
                                            margin=margin)
     else:
         sel_loss, sel_mask = None, self._select_cands_training(input_mask, gold_masks, conf.train_min_rate)
     sel_idxes, sel_valid_mask = BK.mask2idx(sel_mask)  # [*, max-count]
     sel_gold_idxes = gold_idxes.gather(-1, sel_idxes)
     sel_gold_idxes2 = gold_idxes2.gather(-1, sel_idxes)
     # todo(+N): only get items by head position!
     _tmp_i0, _tmp_i1 = np.arange(bsize)[:, np.newaxis], BK.get_value(sel_idxes)
     sel_items = gold_items_arr[_tmp_i0, _tmp_i1]  # [*, mc]
     sel2_items = gold_items2_arr[_tmp_i0, _tmp_i1]
     # step 2: encoding and labeling
     # if we select nothing
     # ----- debug
     # zlog(f"fb-extractor 1: shape sel_idxes = {sel_idxes.shape}")
     # -----
     sel_shape = BK.get_shape(sel_idxes)
     if sel_shape[-1] == 0:
         lab_loss = [[BK.zeros([]), BK.zeros([])]]
         sel2_lab_loss = [[BK.zeros([]), BK.zeros([])]] if self.use_secondary_type else None
         sel_lab_idxes = sel_gold_idxes
         sel_lab_embeds = BK.zeros(sel_shape + [conf.lab_conf.n_dim])
         ret_items = sel_items  # dim-1==0
     else:
         sel_hid_exprs = self._enc(input_lexi, input_expr, input_mask, sel_idxes)  # [*, mc, DLab]
         lab_loss, sel_lab_idxes, sel_lab_embeds = self.hl.loss(
             sel_hid_exprs, sel_valid_mask, sel_gold_idxes, margin=margin)
         if conf.train_gold_corr:
             sel_lab_idxes = sel_gold_idxes
             if not self.hl.conf.use_lookup_soft:
                 sel_lab_embeds = self.hl.lookup(sel_lab_idxes)
         ret_items = sel_items
         # =====
         if self.use_secondary_type:
             sectype_embeds = self.t1tot2(sel_lab_idxes)  # [*, mc, D]
             if conf.sectype_noback_enc:
                 sel2_input = sel_hid_exprs.detach() + sectype_embeds  # [*, mc, D]
             else:
                 sel2_input = sel_hid_exprs + sectype_embeds  # [*, mc, D]
             # =====
             # special for the sectype mask (sample it within the gold ones)
             sel2_valid_mask = self._select_cands_training(
                 (sel_gold_idxes > 0).float(), (sel_gold_idxes2 > 0).float(), conf.train_min_rate_s2)
             # =====
             sel2_lab_loss, sel2_lab_idxes, sel2_lab_embeds = self.hl.loss(
                 sel2_input,
                 sel2_valid_mask,
                 sel_gold_idxes2,
                 margin=margin)
             if conf.train_gold_corr:
                 sel2_lab_idxes = sel_gold_idxes2
                 if not self.hl.conf.use_lookup_soft:
                     sel2_lab_embeds = self.hl.lookup(sel2_lab_idxes)
             if conf.sectype_t2ift1:
                 sel2_lab_idxes = sel2_lab_idxes * (sel_lab_idxes > 0).long()  # pred t2 only if t1 is not 0 (nil)
             # combine the two
             if sel2_lab_idxes.sum().item() > 0:  # if there are any gold sectypes
                 ret_items = np.concatenate([ret_items, sel2_items], -1)  # [*, mc*2]
                 sel_idxes = BK.concat([sel_idxes, sel_idxes], -1)
                 sel_valid_mask = BK.concat([sel_valid_mask, sel2_valid_mask], -1)
                 sel_lab_idxes = BK.concat([sel_lab_idxes, sel2_lab_idxes], -1)
                 sel_lab_embeds = BK.concat([sel_lab_embeds, sel2_lab_embeds], -2)
         else:
             sel2_lab_loss = None
         # =====
         # step 3: exclude nil and return
         if conf.exclude_nil:  # [*, mc', ...]
             sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, _, ret_items = \
                 self._exclude_nil(sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds, sel_items_arr=ret_items)
     # sel_enc_expr = BK.gather_first_dims(input_expr, sel_idxes, -2)  # [*, mc', D]
     # step 4: finally prepare loss and items
     for one_loss in lab_loss:
         one_loss[0] *= conf.lambda_ne
     ret_losses = lab_loss
     if sel2_lab_loss is not None:
         for one_loss in sel2_lab_loss:
             one_loss[0] *= conf.lambda_ne2
         ret_losses = ret_losses + sel2_lab_loss
     if sel_loss is not None:
         for one_loss in sel_loss:
             one_loss[0] *= conf.lambda_ns
         ret_losses = ret_losses + sel_loss
     # ----- debug
     # zlog(f"fb-extractor 2: shape sel_idxes = {sel_idxes.shape}")
     # -----
     # mask out invalid items with None
     ret_items[BK.get_value(1. - sel_valid_mask).astype(bool)] = None
     return ret_losses, ret_items, sel_idxes, sel_valid_mask, sel_lab_idxes, sel_lab_embeds
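Throughout these examples a loss component travels as a [loss_sum, loss_count] pair (see Example #1's return value); scaling only the sum by a lambda reweights that component while leaving the count available for normalization downstream. A toy illustration of the convention (hypothetical lambda):

 lab_loss = [[6.0, 3.0]]  # e.g. summed NLL over 3 valid positions
 lambda_ne = 0.5
 for one_loss in lab_loss:
     one_loss[0] *= lambda_ne  # -> [[3.0, 3.0]]; 3.0/3.0 = mean loss of 1.0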
Example #7
 def forward_batch(self,
                   batched_ids: List,
                   batched_starts: List,
                   batched_typeids: List,
                   training: bool,
                   other_inputs: List[List] = None):
     conf = self.bconf
     tokenizer = self.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     MASK_IDX = tokenizer.mask_token_id
     CLS_IDX = tokenizer.cls_token_id
     SEP_IDX = tokenizer.sep_token_id
     if other_inputs is None:
         other_inputs = []
     # =====
     # batch: here add CLS and SEP
     bsize = len(batched_ids)
     max_len = max(len(z) for z in batched_ids) + 2  # plus [CLS] and [SEP]
     input_shape = (bsize, max_len)
     # first collect on CPU
     input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64)
     input_ids_arr[:, 0] = CLS_IDX
     input_mask_arr = np.full(input_shape, 0, dtype=np.float32)
     input_is_start_arr = np.full(input_shape, 0, dtype=np.int64)
     input_typeids = None if batched_typeids is None else np.full(input_shape, 0, dtype=np.int64)
     other_input_arrs = [np.full(input_shape, 0, dtype=np.int64) for _ in other_inputs]
     if conf.bert2_retinc_cls:  # act as the ROOT word
         input_is_start_arr[:, 0] = 1
     training_mask_rate = conf.bert2_training_mask_rate if training else 0.
     self_sample_stream = self.random_sample_stream
     for bidx in range(bsize):
         cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx]
         cur_end = len(cur_ids) + 2  # plus CLS and SEP
         if training_mask_rate > 0.:
             # input dropout
             input_ids_arr[bidx, 1:cur_end] = \
                 [(MASK_IDX if next(self_sample_stream) < training_mask_rate else z) for z in cur_ids] + [SEP_IDX]
         else:
             input_ids_arr[bidx, 1:cur_end] = cur_ids + [SEP_IDX]
         input_is_start_arr[bidx, 1:cur_end - 1] = cur_starts
         input_mask_arr[bidx, :cur_end] = 1.
         if batched_typeids is not None and batched_typeids[bidx] is not None:
             input_typeids[bidx, 1:cur_end - 1] = batched_typeids[bidx]
         for one_other_input_arr, one_other_input_list in zip(other_input_arrs, other_inputs):
             one_other_input_arr[bidx, 1:cur_end - 1] = one_other_input_list[bidx]
     # arr to tensor
     input_ids_t = BK.input_idx(input_ids_arr)
     input_mask_t = BK.input_real(input_mask_arr)
     input_is_start_t = BK.input_idx(input_is_start_arr)
     input_typeid_t = None if input_typeids is None else BK.input_idx(input_typeids)
     other_input_ts = [BK.input_idx(z) for z in other_input_arrs]
     # =====
     # forward (maybe need multiple times to fit maxlen constraint)
     MAX_LEN = 510  # save two for [CLS] and [SEP]
     BACK_LEN = 100  # for splitting cases, still remaining some of previous sub-tokens for context
     if max_len <= MAX_LEN:
         # directly once
         final_outputs = self.forward_features(
             input_ids_t, input_mask_t, input_typeid_t, other_input_ts)  # [bs, slen, *...]
         start_idxes, start_masks = BK.mask2idx(input_is_start_t.float())  # [bsize, ?]
     else:
         all_outputs = []
         cur_sub_idx = 0
         slice_size = [bsize, 1]
         slice_cls = BK.constants(slice_size, CLS_IDX, dtype=BK.int64)
         slice_sep = BK.constants(slice_size, SEP_IDX, dtype=BK.int64)
         while cur_sub_idx < max_len - 1:  # minus 1 to ignore ending SEP
             cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
             cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
             cur_input_ids_t = BK.concat(
                 [slice_cls, input_ids_t[:, cur_slice_start:cur_slice_end], slice_sep], 1)
             # here we simply extend extra original masks
             cur_input_mask_t = input_mask_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_input_typeid_t = None if input_typeid_t is None \
                 else input_typeid_t[:, cur_slice_start - 1:cur_slice_end + 1]
             cur_other_input_ts = [z[:, cur_slice_start - 1:cur_slice_end + 1] for z in other_input_ts]
             cur_outputs = self.forward_features(
                 cur_input_ids_t, cur_input_mask_t, cur_input_typeid_t, cur_other_input_ts)
             # only include CLS in the first run, no SEP included
             if cur_sub_idx == 0:
                 # include CLS, exclude SEP
                 all_outputs.append(cur_outputs[:, :-1])
             else:
                 # include only new ones, discard BACK ones, exclude CLS, SEP
                 all_outputs.append(cur_outputs[:, cur_sub_idx - cur_slice_start + 1:-1])
                 zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] "
                       f"for all-len={max_len}")
             cur_sub_idx = cur_slice_end
         final_outputs = BK.concat(all_outputs, 1)  # [bs, max_len-1, *...]
         start_idxes, start_masks = BK.mask2idx(input_is_start_t[:, :-1].float())  # [bsize, ?]
     start_expr = BK.gather_first_dims(final_outputs, start_idxes, 1)  # [bsize, ?, *...]
     return start_expr, start_masks  # [bsize, ?, ...], [bsize, ?]
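For inputs longer than BERT's 512-sub-token limit, the loop above slides a window of at most MAX_LEN sub-tokens, re-attaching fresh [CLS]/[SEP] around each slice and carrying BACK_LEN previous sub-tokens along as left context whose outputs are then discarded. A pure-Python trace of the boundary arithmetic (made-up total length, no model involved):

 MAX_LEN, BACK_LEN = 510, 100
 max_len = 1200  # total padded length incl. [CLS]/[SEP] (made-up value)
 cur_sub_idx = 0
 while cur_sub_idx < max_len - 1:
     cur_slice_start = max(1, cur_sub_idx - BACK_LEN)
     cur_slice_end = min(cur_slice_start + MAX_LEN, max_len - 1)
     keep_from = 0 if cur_sub_idx == 0 else cur_sub_idx - cur_slice_start + 1
     print(f"slice [{cur_slice_start}:{cur_slice_end}) keeps outputs from offset {keep_from}")
     cur_sub_idx = cur_slice_end
 # -> [1:511) from 0, [411:921) from 101, [821:1199) from 101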
Example #8
 def loss(self,
          repr_ts,
          input_erase_mask_arr,
          orig_map: Dict,
          active_hid=True,
          **kwargs):
     conf = self.conf
     _tie_input_embeddings = conf.tie_input_embeddings
     # prepare idxes for the masked ones
     if self.add_root_token:  # offset for the special root added in embedder
         mask_idxes, mask_valids = BK.mask2idx(
             BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
         repr_mask_idxes = mask_idxes + 1
         mask_idxes.clamp_(min=0)
     else:
         mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
         repr_mask_idxes = mask_idxes
     # get the losses
     if BK.get_shape(mask_idxes, -1) == 0:  # no loss
         return self._compile_component_loss("mlm", [])
     else:
         if not isinstance(repr_ts, (List, Tuple)):
             repr_ts = [repr_ts]
         target_word_scores = []
         target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
         for layer_idx in conf.loss_layers:
             # calculate scores
             target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
             if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                 target_hids = self.hid_layer(target_reprs)
             else:
                 target_hids = target_reprs
             if _tie_input_embeddings:
                 pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                 target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
             else:
                 target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
         # gather the losses
         all_losses = []
         for pred_name, target_scores, loss_lambda, range_min, range_max in \
                 zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                     [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
             if loss_lambda > 0.:
                 seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                 target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                 ranged_mask_valids = mask_valids * (target_idx_t >= range_min).float() \
                     * (target_idx_t <= range_max).float()
                 target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones are in range
                 # calculate for each layer
                 all_layer_losses, all_layer_scores = [], []
                 for one_layer_idx, one_target_scores in enumerate(target_scores):
                     # get loss: [bsize, ?]
                     one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                     all_layer_losses.append(one_pred_losses)
                     # get scores
                     one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                     all_layer_scores.append(one_pred_scores)
                 # combine all layers
                 pred_losses = self.loss_comb_f(all_layer_losses)
                 pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                 pred_loss_count = ranged_mask_valids.sum()
                 # argmax
                 _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                 pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                 pred_corr_count = pred_corrs.sum()
                 # compile leaf loss
                 r_loss = LossHelper.compile_leaf_info(
                     pred_name,
                     pred_loss_sum,
                     pred_loss_count,
                     loss_lambda=loss_lambda,
                     corr=pred_corr_count)
                 all_losses.append(r_loss)
         return self._compile_component_loss("mlm", all_losses)
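When tie_input_embeddings is enabled, the word scores reuse a slice of the input embedding matrix as the output projection instead of a separate prediction layer (classic weight tying). A minimal sketch under made-up sizes, where E stands in for self.inputter_word_node.E.E:

 import torch
 V, D, P = 1000, 64, 500  # full vocab, hidden dim, predicted subset (hypothetical)
 E = torch.randn(V, D)    # input embedding matrix
 target_hids = torch.randn(2, 5, D)  # [bsize, ?, D]
 scores = target_hids @ E[:P].T      # [bsize, ?, P]: shares weights with the input side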