Example 1
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PosiInputEmbedderConf = self.conf
     # --
     try:
         # input is a (batch_size, max_len) shape pair as prepared by "PosiHelper"
         batch_size, max_len = inputs
         if add_bos:
             max_len += 1
         if add_eos:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [?len?]
         ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     except (TypeError, ValueError):
         # input is a tensor of indexes
         posi_idxes = BK.input_idx(inputs)  # [*, len]
         cur_maxlen = BK.get_shape(posi_idxes, -1)
         # --
         all_input_slices = []
         slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
         if add_bos:  # prepend position 0 and offset the rest by 1
             all_input_slices.append(
                 BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
             cur_maxlen += 1
             posi_idxes += 1
         all_input_slices.append(posi_idxes)  # [*, len]
         if add_eos:
             all_input_slices.append(
                 BK.constants(slice_shape,
                              cur_maxlen,
                              dtype=posi_idxes.dtype))
         final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
         # finally
         ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
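The tensor branch above prepends position 0 for BOS, shifts the original positions by 1, and appends the next free position for EOS before embedding. A minimal stand-alone sketch of that index manipulation, assuming plain PyTorch in place of the BK wrappers (BK.constants roughly corresponding to torch.full and BK.concat to torch.cat):
import torch

def add_bos_eos_positions(posi_idxes: torch.Tensor) -> torch.Tensor:
    # posi_idxes: [*, len] position indexes 0..len-1
    slice_shape = list(posi_idxes.shape[:-1]) + [1]
    cur_maxlen = posi_idxes.shape[-1]
    bos_col = torch.zeros(slice_shape, dtype=posi_idxes.dtype)                 # BOS gets position 0
    eos_col = torch.full(slice_shape, cur_maxlen + 1, dtype=posi_idxes.dtype)  # EOS goes right after the last token
    return torch.cat([bos_col, posi_idxes + 1, eos_col], dim=-1)               # [*, 1+len+1]

positions = torch.arange(5).unsqueeze(0).expand(2, -1)   # [2, 5]
print(add_bos_eos_positions(positions).shape)            # torch.Size([2, 7])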
Example 2
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PlainInputEmbedderConf = self.conf
     # --
     voc = self.voc
     input_t = BK.input_idx(inputs)  # [*, len]
     # rare unk in training
     if self.is_training() and self.use_rare_unk:
         rare_unk_rate = conf.rare_unk_rate
         cur_unk_imask = (
             self.rare_unk_mask[input_t] *
             (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).long()
         input_t = input_t * (1 - cur_unk_imask) + voc.unk * cur_unk_imask
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(input_t)[:-1] + [1]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=input_t.dtype))
     all_input_slices.append(input_t)  # [*, len]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
     # finally
     ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
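The rare-unk step above is a form of word dropout: during training, tokens whose type is marked as rare are replaced by UNK with some probability. A minimal sketch in plain PyTorch, where rare_unk_mask (assumed to be a per-type 0/1 float vector indexed by word id) and unk_id are illustrative inputs:
import torch

def apply_rare_unk(input_t: torch.Tensor, rare_unk_mask: torch.Tensor,
                   unk_id: int, rare_unk_rate: float) -> torch.Tensor:
    # rare_unk_mask[input_t] is 1. where the token type is rare, else 0.
    hit = (rare_unk_mask[input_t] * (torch.rand(input_t.shape) < rare_unk_rate)).long()
    return input_t * (1 - hit) + unk_id * hit   # replace hit positions with UNK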
Example 3
 def _aug_ends(
     self, t: BK.Expr, BOS, PAD, EOS, dtype
 ):  # add BOS(CLS) and EOS(SEP) for a tensor (sub_len -> 1+sub_len+1)
     slice_shape = [self.bsize, 1]
     slices = [
         BK.constants(slice_shape, BOS, dtype=dtype), t,
         BK.constants(slice_shape, PAD, dtype=dtype)
     ]
     aug_batched_ids = BK.concat(slices, -1)  # [bsize, 1+sub_len+1]
     aug_batched_ids[self.arange1_t,
                     self.batched_sublens_p1] = EOS  # assign EOS
     return aug_batched_ids
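The PAD-then-overwrite trick above avoids building ragged sequences: every row is padded to 1+sub_len+1, and EOS is written at each row's own length via advanced indexing. An illustrative stand-alone version in plain PyTorch (bsize, sub_lens, and the id values are made-up placeholders):
import torch

bsize, sub_len = 3, 5
BOS, PAD, EOS = 101, 0, 102
t = torch.randint(1000, 30000, (bsize, sub_len))        # subword ids
sub_lens = torch.tensor([5, 3, 4])                      # real (unpadded) length per row
aug = torch.cat([torch.full((bsize, 1), BOS), t,
                 torch.full((bsize, 1), PAD)], dim=-1)  # [bsize, 1+sub_len+1]
aug[torch.arange(bsize), sub_lens + 1] = EOS            # EOS right after each row's last real subword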
Example 4
 def _get_extra_score(self, cand_score, insts, cand_res, arr_gold_items, use_cons: bool, use_lu: bool):
     # conf: DirectExtractorConf = self.conf
     # --
     # first cand score
     cand_score = self._extend_cand_score(cand_score)
     # then cons_lex score
     cons_lex_node = self.cons_lex_node
     if use_cons and cons_lex_node is not None:
         cons_lex = cons_lex_node.cons
         flt_arr_gold_items = arr_gold_items.flatten()
         _shape = BK.get_shape(cand_res.mask_expr)
         if cand_res.gaddr_expr is None:
             gaddr_expr = BK.constants(_shape, -1, dtype=BK.long)
         else:
             gaddr_expr = cand_res.gaddr_expr
         all_arrs = [BK.get_value(z) for z in [cand_res.widx_expr, cand_res.wlen_expr, cand_res.mask_expr, gaddr_expr]]
         arr_feats = np.full(_shape, None, dtype=object)
         for bidx, inst in enumerate(insts):
             one_arr_feats = arr_feats[bidx]
             _ii = -1
             for one_widx, one_wlen, one_mask, one_gaddr in zip(*[z[bidx] for z in all_arrs]):
                 _ii += 1
                 if one_mask == 0.: continue  # skip invalid ones
                 if use_lu and one_gaddr >= 0:
                     one_feat = cons_lex.lu2feat(flt_arr_gold_items[one_gaddr].info["luName"])
                 else:
                     one_feat = cons_lex.span2feat(inst, one_widx, one_wlen)
                 one_arr_feats[_ii] = one_feat
         cons_valids = cons_lex_node.lookup_with_feats(arr_feats)
         cons_score = (1. - cons_valids) * Constants.REAL_PRAC_MIN
     else:
         cons_score = None
     # sum
     return self._sum_scores(cand_score, cons_score)
Example 5
 def forward_embedding(self, input_ids, attention_mask, token_type_ids,
                       position_ids, other_embeds):
     input_shape = input_ids.size()  # [bsize, len]
     seq_length = input_shape[1]
     if position_ids is None:
         position_ids = BK.arange_idx(seq_length)  # [len]
         position_ids = position_ids.unsqueeze(0).expand(
             input_shape)  # [bsize, len]
     if token_type_ids is None:
         token_type_ids = BK.zeros(input_shape).long()  # [bsize, len]
     # BertEmbeddings.forward
     _embeddings = self.model.embeddings
     inputs_embeds = _embeddings.word_embeddings(
         input_ids)  # [bsize, len, D]
     position_embeddings = _embeddings.position_embeddings(
         position_ids)  # [bsize, len, D]
     token_type_embeddings = _embeddings.token_type_embeddings(
         token_type_ids)  # [bsize, len, D]
     embeddings = inputs_embeds + position_embeddings + token_type_embeddings
     if other_embeds is not None:
         embeddings += other_embeds
     embeddings = _embeddings.LayerNorm(embeddings)
     embeddings = _embeddings.dropout(embeddings)
     # prepare attention_mask
     if attention_mask is None:
         attention_mask = BK.constants(input_shape, value=1.)
     assert attention_mask.dim() == 2
     extended_attention_mask = attention_mask[:, None, None, :]
     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
     return embeddings, extended_attention_mask
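The last two lines build the standard additive attention mask used by BERT-style models: real positions contribute 0 and padded positions a large negative number that is later added to the raw attention logits. Shown standalone in plain PyTorch, assuming a 0/1 padding mask:
import torch

attention_mask = torch.tensor([[1., 1., 1., 0.],
                               [1., 1., 0., 0.]])   # [bsize, len], 1 = real token
extended = attention_mask[:, None, None, :]         # [bsize, 1, 1, len], broadcastable over heads and query positions
extended = (1.0 - extended) * -10000.0              # 0 for real tokens, -10000 for padding
# adding `extended` to the attention scores before softmax drives attention
# to padded keys toward zero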
Example 6
 def _determine_size(self, length: BK.Expr, rate: float, count: float):
     if rate is None:
         ret_size = BK.constants(length.shape, count)  # make it constant
     else:
         ret_size = (length * rate)
         if count is not None:
             ret_size.clamp_(max=count)
     ret_size.ceil_()
     return ret_size
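A quick usage sketch of the rule above (rate times length, optionally capped by count, then rounded up), using plain tensors in place of BK:
import torch

length = torch.tensor([10., 4., 25.])
rate, count = 0.3, 5.0
size = (length * rate).clamp(max=count).ceil()   # tensor([3., 2., 5.])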
Example 7
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: CharCnnInputEmbedderConf = self.conf
     # --
     voc = self.voc
     char_input_t = BK.input_idx(inputs)  # [*, len]
     # todo(note): no need to replace with unk for chars!!
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(char_input_t)
     slice_shape[-2] = 1  # [*, 1, clen]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=char_input_t.dtype))
     all_input_slices.append(char_input_t)  # [*, len, clen]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=char_input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -2)  # [*, 1?+len+1?, clen]
     # char embeddings
     char_embed_expr = self.E(final_input_t)  # [*, ??, dim]
     # char cnn
     ret = self.cnn(char_embed_expr)
     return ret
Example 8
 def s0_open_new_steps(self, bsize: int, ssize: int, mask: BK.Expr = None):
     assert ssize > 0
     assert self._cur_layer_idx == -1
     self._cur_layer_idx = 0
     # --
     new_mask = BK.constants([bsize, ssize], 1.) if mask is None else mask  # [*, ssize]
     # --
     # prepare for store_lstate selecting
     if len(self.cum_state_lset) > 0:  # does any layer need to accumulate?
         self._arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
         # note: if there is no last state, simply clamp at 0; otherwise offset by 1, since we will concat later
         self._arange_sel_t = mask2posi_padded(new_mask, 0, 0) if mask is None else mask2posi_padded(new_mask, 1, 0)
     # prev_steps = self.steps  # previous accumulated steps
     self.steps += ssize
     self.mask = new_mask if self.mask is None else BK.concat([self.mask, new_mask], 1)  # [*, old+new]
     self.positions = mask2posi(self.mask, offset=-1, cmin=0)  # [*, old+new], recalculate!!
Example 9
def nmst_greedy(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False):
    assert labeled
    with BK.no_grad_env():
        scores_shape = BK.get_shape(scores_expr)
        maxlen = scores_shape[1]
        # mask out diag
        scores_expr += BK.diagflat(BK.constants([maxlen], Constants.REAL_PRAC_MIN)).unsqueeze(-1)
        # combine the last two dimensions and take the max over them
        combined_scores_expr = scores_expr.view(scores_shape[:-2] + [-1])
        combine_max_scores, combined_max_idxes = BK.max(combined_scores_expr, dim=-1)
        # back to real idxes
        last_size = scores_shape[-1]
        greedy_heads = combined_max_idxes // last_size
        greedy_labels = combined_max_idxes % last_size
        if ret_arr:
            mst_heads_arr, mst_labels_arr, mst_scores_arr = [BK.get_value(z) for z in (greedy_heads, greedy_labels, combine_max_scores)]
            return mst_heads_arr, mst_labels_arr, mst_scores_arr
        else:
            return greedy_heads, greedy_labels, combine_max_scores
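The `// last_size` and `% last_size` decoding above recovers both indexes from a single argmax over the flattened head-label axis. In isolation, for one dependent with scores of shape [num_heads, num_labels] (plain PyTorch, illustrative sizes):
import torch

m, L = 4, 3
scores = torch.randn(m, L)              # [heads, labels] for one dependent
flat_idx = scores.view(-1).argmax()     # index into the flattened [m*L] vector
head = flat_idx // L                    # which head
label = flat_idx % L                    # which label
assert scores[head, label] == scores.view(-1)[flat_idx]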
Example 10
 def loss(self, flt_items, flt_input_expr, flt_pair_expr, flt_full_expr, flt_mask_expr, flt_extra_weights=None):
     conf: ExtenderConf = self.conf
     _loss_lambda = conf._loss_lambda
     # --
     enc_t = self._forward_eenc(flt_input_expr, flt_full_expr, flt_mask_expr)  # [*, slen, D]
     s_left, s_right = self.enode.score(enc_t, flt_pair_expr if conf.ext_use_finput else None, flt_mask_expr)  # [*, slen]
     # --
     gold_posi = [self.ext_span_getter(z.mention) for z in flt_items]  # List[(widx, wlen)]
     widx_t = BK.input_idx([z[0] for z in gold_posi])  # [*]
     wlen_t = BK.input_idx([z[1] for z in gold_posi])
     loss_left_t, loss_right_t = BK.loss_nll(s_left, widx_t), BK.loss_nll(s_right, widx_t+wlen_t-1)  # [*]
     if flt_extra_weights is not None:
         loss_left_t *= flt_extra_weights
         loss_right_t *= flt_extra_weights
         loss_div = flt_extra_weights.sum()  # note: also use this!
     else:
         loss_div = BK.constants([len(flt_items)], value=1.).sum()
     loss_left_item = LossHelper.compile_leaf_loss("left", loss_left_t.sum(), loss_div, loss_lambda=_loss_lambda)
     loss_right_item = LossHelper.compile_leaf_loss("right", loss_right_t.sum(), loss_div, loss_lambda=_loss_lambda)
     ret_loss = LossHelper.combine_multiple_losses([loss_left_item, loss_right_item])
     return ret_loss
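The core of the loss above is two token-level classifications: score every position as a candidate left or right boundary and take NLL against the gold start widx and the gold end widx+wlen-1. A stripped-down sketch in plain PyTorch (shapes and values are illustrative, and BK.loss_nll is assumed to behave like per-example cross entropy):
import torch
import torch.nn.functional as F

bs, slen = 3, 7
s_left = torch.randn(bs, slen)          # left-boundary scores per token
s_right = torch.randn(bs, slen)         # right-boundary scores per token
widx = torch.tensor([1, 0, 3])          # gold span starts
wlen = torch.tensor([2, 4, 1])          # gold span lengths
loss_left = F.cross_entropy(s_left, widx, reduction='none')               # [bs]
loss_right = F.cross_entropy(s_right, widx + wlen - 1, reduction='none')  # [bs]
loss = (loss_left + loss_right).sum() / bs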
Example 11
 def _split_extend(self, split_decisions: BK.Expr, cand_mask: BK.Expr):
     # first augment/pad split_decisions
     slice_ones = BK.constants([BK.get_shape(split_decisions, 0), 1],
                               1.)  # [*, 1]
     padded_split_decisions = BK.concat([slice_ones, split_decisions],
                                        -1)  # [*, clen]
     seg_cidxes, seg_masks = BK.mask2idx(
         padded_split_decisions)  # [*, seglen]
     # --
     cand_lens = cand_mask.sum(-1, keepdim=True).long()  # [*, 1]
     seg_masks *= (cand_lens > 0).float()  # for the case of no cands
     # --
     seg_cidxes_special = seg_cidxes + (1. - seg_masks).long(
     ) * cand_lens  # [*, seglen], fill in for paddings
     seg_cidxes_special2 = BK.concat([seg_cidxes_special, cand_lens],
                                     -1)  # [*, seglen+1]
     seg_clens = seg_cidxes_special2[:,
                                     1:] - seg_cidxes_special  # [*, seglen]
     # extend the idxes
     seg_ext_cidxes, seg_ext_masks = expand_ranged_idxes(
         seg_cidxes, seg_clens)  # [*, seglen, MW]
     seg_ext_masks *= seg_masks.unsqueeze(-1)
     return seg_ext_cidxes, seg_ext_masks, seg_masks  # 2x[*, seglen, MW], [*, seglen]
Example 12
 def forward(self, input_map: Dict):
     add_bos, add_eos = self.conf.add_bos, self.conf.add_eos
     ret = OrderedDict()  # [*, len, ?]
     for key, embedder_pack in self.embedders.items(
     ):  # according to REG order!!
         embedder, input_name = embedder_pack
         one_expr = embedder(input_map[input_name],
                             add_bos=add_bos,
                             add_eos=add_eos)
         ret[key] = one_expr
     # mask expr
     mask_expr = input_map.get("mask")
     if mask_expr is not None:
         all_input_slices = []
         mask_slice = BK.constants(BK.get_shape(mask_expr)[:-1] + [1],
                                   1,
                                   dtype=mask_expr.dtype)  # [*, 1]
         if add_bos:
             all_input_slices.append(mask_slice)
         all_input_slices.append(mask_expr)
         if add_eos:
             all_input_slices.append(mask_slice)
         mask_expr = BK.concat(all_input_slices, -1)  # [*, ?+len+?]
     return mask_expr, ret
Example 13
 def beam_search(self, batch_size: int, beam_k: int, ret_best: bool = True):
     _NEG_INF = Constants.REAL_PRAC_MIN
     # --
     cur_step = 0
     cache: DecCache = None
     # init: keep the seq of scores rather than traceback!
     start_vals_shape = [batch_size, 1]  # [bs, 1]
     all_preds_t = BK.constants_idx(start_vals_shape, 0).unsqueeze(
         -1)  # [bs, K, step], todo(note): start with 0!
     all_scores_t = BK.zeros(start_vals_shape).unsqueeze(
         -1)  # [bs, K, step]
     accu_scores_t = BK.zeros(start_vals_shape)  # [bs, K]
     arange_t = BK.arange_idx(batch_size).unsqueeze(-1)  # [bs, 1]
     # while loop
     prev_k = 1  # start with single one
     while not self.is_end(cur_step):
         # expand and score
         cache, scores_t, masks_t = self.step_score(
             cur_step, prev_k, cache)  # ..., [bs*pK, L], [bs*pK]
         scores_t_shape = BK.get_shape(scores_t)
         last_dim = scores_t_shape[-1]  # L
         # modify score to handle mask: keep previous pred for the masked items!
         sel_scores_t = BK.constants([batch_size, prev_k, last_dim],
                                     1.)  # [bs, pk, L]
         sel_scores_t.scatter_(-1, all_preds_t[:, :, -1:],
                               -1)  # [bs, pk, L]
         sel_scores_t = scores_t + _NEG_INF * (
             sel_scores_t.view(scores_t_shape) *
             (1. - masks_t).unsqueeze(-1))  # [bs*pK, L]
         # first select topk locally, note: here no need to sort!
         local_k = min(last_dim, beam_k)
         l_topk_scores, l_topk_idxes = sel_scores_t.topk(
             local_k, -1, sorted=False)  # [bs*pK, lK]
         # then topk globally on full pK*K
         add_score_shape = [batch_size, prev_k, local_k]
         to_sel_shape = [batch_size, prev_k * local_k]
         global_k = min(to_sel_shape[-1], beam_k)  # new k
         to_sel_scores, to_sel_idxes = \
             (l_topk_scores.view(add_score_shape) + accu_scores_t.unsqueeze(-1)).view(to_sel_shape), \
             l_topk_idxes.view(to_sel_shape)  # [bs, pK*lK]
         _, g_topk_idxes = to_sel_scores.topk(global_k, -1,
                                              sorted=True)  # [bs, gK]
         # get to know the idxes
         new_preds_t = to_sel_idxes.gather(-1, g_topk_idxes)  # [bs, gK]
         new_pk_idxes = (
             g_topk_idxes // local_k
         )  # which previous idx (in beam) is selected? [bs, gK]
         # get current pred and scores (handling mask)
         scores_t3 = scores_t.view([batch_size, -1,
                                    last_dim])  # [bs, pK, L]
         masks_t2 = masks_t.view([batch_size, -1])  # [bs, pK]
         new_masks_t = masks_t2[arange_t, new_pk_idxes]  # [bs, gK]
         # -- one-step score for new selections: [bs, gK], note: zero scores for masked ones
         new_scores_t = scores_t3[arange_t, new_pk_idxes,
                                  new_preds_t] * new_masks_t  # [bs, gK]
         # ending
         new_arrange_idxes = (arange_t * prev_k + new_pk_idxes).view(
             -1)  # [bs*gK]
         cache.arrange_idxes(new_arrange_idxes)
         self.step_end(cur_step, global_k, cache,
                       new_preds_t.view(-1))  # modify in cache
         # prepare next & judge ending
         all_preds_t = BK.concat([
             all_preds_t[arange_t, new_pk_idxes],
             new_preds_t.unsqueeze(-1)
         ], -1)  # [bs, gK, step]
         all_scores_t = BK.concat([
             all_scores_t[arange_t, new_pk_idxes],
             new_scores_t.unsqueeze(-1)
         ], -1)  # [bs, gK, step]
         accu_scores_t = accu_scores_t[
             arange_t, new_pk_idxes] + new_scores_t  # [bs, gK]
         prev_k = global_k  # for next step
         cur_step += 1
     # --
     # sort and return at the final step
     _, final_idxes = accu_scores_t.topk(prev_k, -1, sorted=True)  # [bs, K]
     ret_preds = all_preds_t[
         arange_t, final_idxes][:, :,
                                1:]  # [bs, K, steps], exclude dummy start!
     ret_scores = all_scores_t[arange_t, final_idxes][:, :,
                                                      1:]  # [bs, K, steps]
     if ret_best:
         return ret_preds[:, 0], ret_scores[:, 0]  # [bs, slen]
     else:
         return ret_preds, ret_scores  # [bs, topk, slen]
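The heart of each beam step above is a two-stage top-k: first a local top-k inside every expanded beam item, then a global top-k over the pooled prev_k * local_k candidates (with the accumulated beam scores added in before the global selection in the real code). A compact sketch of just that selection, in plain PyTorch with illustrative sizes:
import torch

bs, prev_k, L, beam_k = 2, 3, 10, 4
scores = torch.randn(bs * prev_k, L)                            # [bs*pK, L] one-step scores
local_k = min(L, beam_k)
l_scores, l_idxes = scores.topk(local_k, dim=-1, sorted=False)  # [bs*pK, lK]
pooled = l_scores.view(bs, prev_k * local_k)                    # [bs, pK*lK] (accu_scores would be added here)
global_k = min(prev_k * local_k, beam_k)
g_scores, g_idxes = pooled.topk(global_k, dim=-1, sorted=True)  # [bs, gK]
new_preds = l_idxes.view(bs, -1).gather(-1, g_idxes)            # predicted ids of the survivors
which_prev = g_idxes // local_k                                 # which previous beam item each survivor extends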
Example 14
 def gather_losses(self,
                   scores: List[BK.Expr],
                   label_t: BK.Expr,
                   valid_t: BK.Expr,
                   loss_neg_sample: float = None):
     conf: ZLabelConf = self.conf
     _loss_do_sel = conf.loss_do_sel
     _alpha_binary, _alpha_full = conf.loss_binary_alpha, conf.loss_full_alpha
     _alpha_all_binary = conf.loss_allbinary_alpha
     # --
     if self.crf is not None:  # CRF mode!
         assert _alpha_binary <= 0. and _alpha_all_binary <= 0.
         # reshape them into 3d
         valid_premask = (valid_t.sum(-1) > 0.)  # [bs, ...]
         # note: simply collect them all
         rets = []
         _pm_mask, _pm_label = valid_t[valid_premask], label_t[
             valid_premask]  # [??, slen]
         for score_t in scores:
             _one_pm_score = score_t[valid_premask]  # [??, slen, D]
             _one_fscore_t, _ = self._get_score(
                 _one_pm_score)  # [??, slen, L]
             # --
             # todo(+N): hacky fix, make it a leading NIL
             _pm_mask2 = _pm_mask.clone()
             _pm_mask2[:, 0] = 1.
             # --
             _one_loss, _one_count = self.crf.loss(_one_fscore_t, _pm_mask2,
                                                   _pm_label)  # ??
             rets.append((_one_loss * _alpha_full, _one_count))
     else:
         pos_t = (label_t > 0).float()  # 0 as NIL!!
         loss_mask_t = self._get_loss_mask(
             pos_t, valid_t, loss_neg_sample=loss_neg_sample)  # [bs, ...]
         if _loss_do_sel:
             _sel_mask = (loss_mask_t > 0.)  # [bs, ...]
             _sel_label = label_t[_sel_mask]  # [??]
             _sel_mask2 = BK.constants([len(_sel_label)], 1.)  # [??]
         # note: simply collect them all
         rets = []
         for score_t in scores:
             if _loss_do_sel:  # [??, ]
                 one_score_t, one_mask_t, one_label_t = score_t[
                     _sel_mask], _sel_mask2, _sel_label
             else:  # [bs, ..., D]
                 one_score_t, one_mask_t, one_label_t = score_t, loss_mask_t, label_t
             one_fscore_t, one_nilscore_t = self._get_score(one_score_t)
             # full loss
             one_loss_t = BK.loss_nll(one_fscore_t,
                                      one_label_t) * _alpha_full  # [????]
             # binary loss
             if _alpha_binary > 0.:  # plus ...
                 _binary_loss = BK.loss_binary(
                     one_nilscore_t.squeeze(-1),
                     (one_label_t > 0).float()) * _alpha_binary  # [???]
                 one_loss_t = one_loss_t + _binary_loss
             # all binary
             if _alpha_all_binary > 0.:  # plus ...
                 _tmp_label_t = BK.zeros(
                     BK.get_shape(one_fscore_t))  # [???, L]
                 _tmp_label_t.scatter_(-1, one_label_t.unsqueeze(-1), 1.)
                 _ab_loss = BK.loss_binary(
                     one_fscore_t,
                     _tmp_label_t) * _alpha_all_binary  # [???, L]
                 one_loss_t = one_loss_t + _ab_loss[..., 1:].sum(-1)
             # --
             one_loss_t = one_loss_t * one_mask_t
             rets.append((one_loss_t, one_mask_t))  # tuple(loss, mask)
     return rets
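The `loss_do_sel` branch above is an efficiency trick: gather only the positions whose loss mask is positive, so the per-label scoring runs on a flat [N_sel, D] tensor instead of the full [bs, ..., D] one. A minimal sketch of the selection itself (plain PyTorch, made-up shapes):
import torch

bs, slen, D = 2, 6, 8
score_t = torch.randn(bs, slen, D)
loss_mask_t = (torch.rand(bs, slen) > 0.5).float()
sel_mask = loss_mask_t > 0.                        # [bs, slen] bool
sel_scores = score_t[sel_mask]                     # [N_sel, D], only the selected positions
sel_weights = torch.ones(sel_scores.shape[0])      # [N_sel], the mask becomes all ones after selection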
Example 15
 def score_all(self,
               expr_main: BK.Expr,
               expr_pair: BK.Expr,
               input_mask: BK.Expr,
               gold_idxes: BK.Expr,
               local_normalize: bool = None,
               use_bigram: bool = True,
               extra_score: BK.Expr = None):
     conf: SeqLabelerConf = self.conf
     # first collect basic scores
     if conf.use_seqdec:
         # first prepare init hidden
         sd_init_t = self.prepare_sd_init(expr_main, expr_pair)  # [*, hid]
         # init cache: no mask at batch level
         sd_cache = self.seqdec.go_init(
             sd_init_t, init_mask=None)  # and no need to cum_state here!
         # prepare inputs at once
         if conf.sd_skip_non:
             gold_valid_mask = (gold_idxes > 0).float(
             ) * input_mask  # [*, slen], todo(note): fix 0 as non here!
             gv_idxes, gv_masks = BK.mask2idx(gold_valid_mask)  # [*, ?]
             bsize = BK.get_shape(gold_idxes, 0)
             arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
             # select and forward
             gv_embeds = self.laber.lookup(
                 gold_idxes[arange_t, gv_idxes])  # [*, ?, E]
             gv_input_t = self.sd_input_aff(
                 [expr_main[arange_t, gv_idxes], gv_embeds])  # [*, ?, hid]
             gv_hid_t = self.seqdec.go_feed(sd_cache, gv_input_t,
                                            gv_masks)  # [*, ?, hid]
             # select back and output_aff
             aug_hid_t = BK.concat([sd_init_t.unsqueeze(-2), gv_hid_t],
                                   -2)  # [*, 1+?, hid]
             sel_t = BK.pad(gold_valid_mask[:, :-1].cumsum(-1), (1, 0),
                            value=0.).long()  # [*, 1+(slen-1)]
             shifted_hid_t = aug_hid_t[arange_t, sel_t]  # [*, slen, hid]
         else:
             gold_idx_embeds = self.laber.lookup(gold_idxes)  # [*, slen, E]
             all_input_t = self.sd_input_aff(
                 [expr_main,
                  gold_idx_embeds])  # inputs to dec, [*, slen, hid]
             all_hid_t = self.seqdec.go_feed(
                 sd_cache, all_input_t,
                 input_mask)  # output-hids, [*, slen, hid]
             shifted_hid_t = BK.concat(
                 [sd_init_t.unsqueeze(-2), all_hid_t[:, :-1]],
                 -2)  # [*, slen, hid]
         # scorer
         pre_labeler_t = self.sd_output_aff([expr_main, shifted_hid_t
                                             ])  # [*, slen, hid]
     else:
         pre_labeler_t = expr_main  # [*, slen, Dm']
     # score with labeler (no norm here since we may need to add other scores)
     scores_t = self.laber.score(
         pre_labeler_t,
         None if expr_pair is None else expr_pair.unsqueeze(-2),
         input_mask,
         extra_score=extra_score,
         local_normalize=False)  # [*, slen, L]
     # bigram score addition
     if conf.use_bigram and use_bigram:
         bigram_scores_t = self.bigram.get_matrix()[
             gold_idxes[:, :-1]]  # [*, slen-1, L]
         score_shape = BK.get_shape(bigram_scores_t)
         score_shape[1] = 1
         slice_t = BK.constants(
             score_shape,
             0.)  # fix 0., no transition from BOS (and EOS) for simplicity!
         bigram_scores_t = BK.concat([slice_t, bigram_scores_t],
                                     1)  # [*, slen, L]
         scores_t += bigram_scores_t  # [*, slen]
     # local normalization?
     scores_t = self.laber.output_score(scores_t, local_normalize)
     return scores_t
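The bigram addition at the end indexes an [L, L] transition matrix by the previous gold label to obtain a transition score for every current label, and pads the first position with zeros since there is no transition from BOS. In isolation (plain PyTorch, illustrative sizes):
import torch

bs, slen, L = 2, 5, 4
transition = torch.randn(L, L)                           # transition[prev_label, cur_label]
gold_idxes = torch.randint(L, (bs, slen))
bigram_scores = transition[gold_idxes[:, :-1]]           # [bs, slen-1, L]
pad = torch.zeros(bs, 1, L)                              # no transition score for position 0
bigram_scores = torch.cat([pad, bigram_scores], dim=1)   # [bs, slen, L], added to the unigram scores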