Example #1
 def forward(self,
             expr_t: BK.Expr,
             mask_t: BK.Expr,
             scores_t=None,
             **kwargs):
     conf: IdecConnectorAttConf = self.conf
     # --
     # prepare input
     _d_bs, _dq, _dk, _d_nl, _d_nh = BK.get_shape(scores_t)
     in1_t = scores_t[:, :, :, self.lstart:, :self.head_end].reshape(
         [_d_bs, _dq, _dk, self.d_in])  # [*, lenq, lenk, din]
     in2_t = in1_t.transpose(-3, -2)  # [*, lenk, lenq, din]
     final_input_t = BK.concat([in1_t, in2_t], -1)  # [*, lenk, lenq, din*2]
     # forward
     node_ret_t = self.node.forward(final_input_t, mask_t, self.feed_output,
                                    self.lidx,
                                    **kwargs)  # [*, lenq, lenk, head_end]
     if self.feed_output:
         # pad zeros if necessary
         if self.head_end < _d_nh:
             pad_t = BK.zeros([_d_bs, _dq, _dk, _d_nh - self.head_end])
             node_ret_t = BK.concat([node_ret_t, pad_t], -1)  # [*, lenq, lenk, Hin]
         return node_ret_t
     else:
         return None
Example #2
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PlainInputEmbedderConf = self.conf
     # --
     voc = self.voc
     input_t = BK.input_idx(inputs)  # [*, len]
     # rare unk in training
     if self.is_training() and self.use_rare_unk:
         rare_unk_rate = conf.rare_unk_rate
         cur_unk_imask = (
             self.rare_unk_mask[input_t] *
             (BK.rand(BK.get_shape(input_t)) < rare_unk_rate)).long()
         input_t = input_t * (1 - cur_unk_imask) + voc.unk * cur_unk_imask
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(input_t)[:-1] + [1]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=input_t.dtype))
     all_input_slices.append(input_t)  # [*, len]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
     # finally
     ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
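The BOS/EOS handling above is a recurring pattern in these examples: build a [*, 1] constant slice per special token and concatenate along the last axis. A minimal stand-alone sketch of the same pattern, assuming BK is a thin wrapper around PyTorch (the function name and vocab ids below are made up for illustration):

 import torch

 def add_bos_eos(idx_t: torch.Tensor, bos_id: int, eos_id: int,
                 add_bos: bool = True, add_eos: bool = True) -> torch.Tensor:
     # idx_t: [*, len] integer token ids
     slice_shape = list(idx_t.shape[:-1]) + [1]
     pieces = []
     if add_bos:
         pieces.append(torch.full(slice_shape, bos_id, dtype=idx_t.dtype))
     pieces.append(idx_t)
     if add_eos:
         pieces.append(torch.full(slice_shape, eos_id, dtype=idx_t.dtype))
     return torch.cat(pieces, dim=-1)  # [*, 1?+len+1?]

 ids = torch.tensor([[5, 6, 7], [8, 9, 10]])
 print(add_bos_eos(ids, bos_id=1, eos_id=2))  # [[1, 5, 6, 7, 2], [1, 8, 9, 10, 2]]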
Example #3
 def forward(self, input_map: Dict):
     mask_expr, expr_map = self.eg.forward(input_map)
     exprs = list(expr_map.values())  # follow the order in OrderedDict
     # concat and final
     concat_expr = BK.concat(exprs, -1)  # [*, len, SUM]
     final_expr = self.final_layer(concat_expr)
     return mask_expr, final_expr
Example #4
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: PosiInputEmbedderConf = self.conf
     # --
     try:
         # input is a shape as prepared by "PosiHelper"
         batch_size, max_len = inputs
         if add_bos:
             max_len += 1
         if add_eos:
             max_len += 1
         posi_idxes = BK.arange_idx(max_len)  # [?len?]
         ret = self.E(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
     except Exception:
         # input is tensor
         posi_idxes = BK.input_idx(inputs)  # [*, len]
         cur_maxlen = BK.get_shape(posi_idxes, -1)
         # --
         all_input_slices = []
         slice_shape = BK.get_shape(posi_idxes)[:-1] + [1]
         if add_bos:  # add 0 and offset
             all_input_slices.append(
                 BK.constants(slice_shape, 0, dtype=posi_idxes.dtype))
             cur_maxlen += 1
             posi_idxes += 1
         all_input_slices.append(posi_idxes)  # [*, len]
         if add_eos:
             all_input_slices.append(
                 BK.constants(slice_shape,
                              cur_maxlen,
                              dtype=posi_idxes.dtype))
         final_input_t = BK.concat(all_input_slices, -1)  # [*, 1?+len+1?]
         # finally
         ret = self.E(final_input_t)  # [*, ??, dim]
     return ret
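In the tensor branch above, prepending a BOS slot shifts every given position index by one, and the EOS slot takes the next free position. A small sketch of just that index bookkeeping in plain torch (assumed to mirror what BK.constants and BK.concat are doing here):

 import torch

 def build_posi_idxes(posi: torch.Tensor, add_bos: bool, add_eos: bool) -> torch.Tensor:
     # posi: [*, len] position indices
     cur_maxlen = posi.shape[-1]
     slice_shape = list(posi.shape[:-1]) + [1]
     pieces = []
     if add_bos:  # position 0 for BOS, original positions shift right by one
         pieces.append(torch.zeros(slice_shape, dtype=posi.dtype))
         cur_maxlen += 1
         posi = posi + 1
     pieces.append(posi)
     if add_eos:  # EOS takes the next free position
         pieces.append(torch.full(slice_shape, cur_maxlen, dtype=posi.dtype))
     return torch.cat(pieces, dim=-1)  # [*, 1?+len+1?]

 print(build_posi_idxes(torch.arange(4).unsqueeze(0), True, True))  # [[0, 1, 2, 3, 4, 5]]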
Example #5
 def ss_add_new_layer(self, layer_idx: int, expr: BK.Expr):
     assert layer_idx == self._cur_layer_idx
     self._cur_layer_idx += 1
     # --
     if self.states is None:
         self.states = []
     _cur_cum_state = (layer_idx in self.cum_state_lset)
     added_expr = expr
     if len(self.states) == layer_idx:  # first group of calls
         if _cur_cum_state:  # direct select!
             added_expr = added_expr[self._arange2_t, self._arange_sel_t]
         self.states.append(added_expr)  # directly add as first-time adding
     else:
         prev_state_all = self.states[layer_idx]  # [bsize, old_step]
         if _cur_cum_state:  # concat last and select
             added_expr = BK.concat([prev_state_all[:, -1].unsqueeze(1), added_expr], 1)[self._arange2_t, self._arange_sel_t]
         self.states[layer_idx] = BK.concat([prev_state_all, added_expr], 1)  # [*, old+new, D]
     return added_expr, self.states[layer_idx]  # q; kv
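The state bookkeeping above boils down to appending new time steps to a cached per-layer tensor along dim 1, the same idea as a key/value cache. A stripped-down sketch in plain torch, with a hypothetical cache class in place of the incremental-state object used here:

 import torch

 class StepCache:
     """Accumulate per-layer states along the step axis (dim=1)."""
     def __init__(self, num_layers: int):
         self.states = [None] * num_layers            # each entry: [bsize, steps, D] or None

     def add(self, layer_idx: int, new_expr: torch.Tensor) -> torch.Tensor:
         prev = self.states[layer_idx]
         if prev is None:                             # first group of calls: store directly
             self.states[layer_idx] = new_expr
         else:                                        # later calls: concat old and new steps
             self.states[layer_idx] = torch.cat([prev, new_expr], dim=1)
         return self.states[layer_idx]                # [bsize, old+new, D]

 cache = StepCache(num_layers=1)
 cache.add(0, torch.randn(2, 3, 8))                   # first 3 steps
 print(cache.add(0, torch.randn(2, 1, 8)).shape)      # torch.Size([2, 4, 8])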
Example #6
 def forward(self, med: ZMediator):
     conf: Idec2Conf = self.conf
     cur_lidx = med.lidx
     assert cur_lidx == self.max_app_lidx
     seq_info = med.ibatch.seq_info
     # --
     # get att values
     # todo(+N): modify med to make (cached) things more flexible!
     v_att_final = med.get_cache(self.gatt_key)
     if v_att_final is None:
         v_att = BK.concat(
             med.get_enc_cache(
                 conf.gatt_name).vals[self.min_gatt_lidx:cur_lidx],
             1)  # [*, L*H, h, m]
         v_att_rm = self.dsel_rm(v_att.permute(0, 3, 2, 1),
                                 seq_info)  # first reduce m: [*,m',h,L*H]
         v_att_rh = self.dsel_rh(v_att_rm.transpose(1, 2),
                                 seq_info)  # then reduce h: [*,h',m',L*H]
         v_att_final = BK.concat(
             [v_att_rh, v_att_rh.transpose(1, 2)],
             -1)  # final concat: [*,h',m',L*H*2]
         med.set_cache(self.gatt_key, v_att_final)
     hid_inputs = [self.gatt_drop(v_att_final)]
     # --
     # get hid values
     if conf.ghid_m or conf.ghid_h:
         _dsel = self.dsel_hid
         v_hid = med.get_enc_cache_val(  # [*, len', D]
             "hid",
             signature=_dsel.signature,
             function=(lambda x: _dsel.forward(x, seq_info)))
         if conf.ghid_h:
             hid_inputs.append(v_hid.unsqueeze(-2))  # [*, h, 1, D]
         if conf.ghid_m:
             hid_inputs.append(v_hid.unsqueeze(-3))  # [*, 1, m, D]
     # --
     # go
     ret = self.aff_hid(hid_inputs)
     if self.aff_final is not None:
         ret = self.aff_final(ret)
     return ret, None  # currently no feed!
Example #7
 def _aug_ends(
     self, t: BK.Expr, BOS, PAD, EOS, dtype
 ):  # add BOS(CLS) and EOS(SEP) for a tensor (sub_len -> 1+sub_len+1)
     slice_shape = [self.bsize, 1]
     slices = [
         BK.constants(slice_shape, BOS, dtype=dtype), t,
         BK.constants(slice_shape, PAD, dtype=dtype)
     ]
     aug_batched_ids = BK.concat(slices, -1)  # [bsize, 1+sub_len+1]
     aug_batched_ids[self.arange1_t,
                     self.batched_sublens_p1] = EOS  # assign EOS
     return aug_batched_ids
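_aug_ends pads with BOS at the front and a PAD placeholder at the back, then writes EOS at each sequence's true end via advanced indexing (batched_sublens_p1 is presumably the per-example sub-length plus one). A self-contained sketch of the same trick in plain torch, with hypothetical ids:

 import torch

 def aug_ends(t: torch.Tensor, sub_lens: torch.Tensor, bos: int, pad: int, eos: int) -> torch.Tensor:
     # t: [bsize, sub_len] padded ids; sub_lens: [bsize] true lengths per row
     bsize = t.shape[0]
     slice_shape = (bsize, 1)
     aug = torch.cat([torch.full(slice_shape, bos, dtype=t.dtype), t,
                      torch.full(slice_shape, pad, dtype=t.dtype)], dim=-1)  # [bsize, 1+sub_len+1]
     aug[torch.arange(bsize), sub_lens + 1] = eos   # write EOS right after the last real token
     return aug

 ids = torch.tensor([[7, 8, 0], [5, 6, 9]])          # 0 = padding in the first row
 print(aug_ends(ids, torch.tensor([2, 3]), bos=101, pad=0, eos=102))
 # [[101, 7, 8, 102, 0], [101, 5, 6, 9, 102]]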
Example #8
 def forward(self, input_expr: BK.Expr, widx_expr: BK.Expr, wlen_expr: BK.Expr):
     conf: BaseSpanConf = self.conf
     # --
     # note: check empty, otherwise error
     input_item_shape = BK.get_shape(widx_expr)
     if np.prod(input_item_shape) == 0:
         return BK.zeros(input_item_shape + [self.output_dim])  # return an empty but shaped tensor
     # --
     start_idxes, end_idxes = widx_expr, widx_expr+wlen_expr  # make [start, end)
     # get sizes
     bsize, slen = BK.get_shape(input_expr)[:2]
     # num_span = BK.get_shape(start_idxes, 1)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
     # --
     reprs = []
     if conf.use_starts:  # start [start,
         reprs.append(input_expr[arange2_t, start_idxes])  # [bsize, ?, D]
     if conf.use_ends:  # simply ,end-1]
         reprs.append(input_expr[arange2_t, end_idxes-1])
     if conf.use_softhead:
         # expand range
         all_span_idxes, all_span_mask = expand_ranged_idxes(widx_expr, wlen_expr, 0, None)  # [bsize, ?, MW]
         # flatten
         flatten_all_span_idxes = all_span_idxes.view(bsize, -1)  # [bsize, ?*MW]
         flatten_all_span_mask = all_span_mask.view(bsize, -1)  # [bsize, ?*MW]
         # get softhead score (consider mask here)
         softhead_scores = self.softhead_scorer(input_expr).squeeze(-1)  # [bsize, slen]
         flatten_all_span_scores = softhead_scores[arange2_t, flatten_all_span_idxes]  # [bsize, ?*MW]
         flatten_all_span_scores += (1.-flatten_all_span_mask) * Constants.REAL_PRAC_MIN
         all_span_scores = flatten_all_span_scores.view(all_span_idxes.shape)  # [bsize, ?, MW]
         # reshape and (optionally topk) and softmax
         softhead_topk = conf.softhead_topk
         if softhead_topk>0 and BK.get_shape(all_span_scores,-1)>softhead_topk:  # further select topk; note: this may save mem
             final_span_score, _tmp_idxes = all_span_scores.topk(softhead_topk, dim=-1, sorted=False)  # [bsize, ?, K]
             final_span_idxes = all_span_idxes.gather(-1, _tmp_idxes)  # [bsize, ?, K]
         else:
             final_span_score, final_span_idxes = all_span_scores, all_span_idxes  # [bsize, ?, MW]
         final_prob = final_span_score.softmax(-1)  # [bsize, ?, ??]
         # [bsize, ?, ??, D]
         final_repr = input_expr[arange2_t, final_span_idxes.view(bsize, -1)].view(BK.get_shape(final_span_idxes)+[-1])
         weighted_repr = (final_repr * final_prob.unsqueeze(-1)).sum(-2)  # [bsize, ?, D]
         reprs.append(weighted_repr)
     if conf.use_width:
         cur_width_embed = self.width_embed(wlen_expr)  # [bsize, ?, DE]
         reprs.append(cur_width_embed)
     # concat
     concat_repr = BK.concat(reprs, -1)  # [bsize, ?, SUM]
     if conf.use_proj:
         ret = self.final_proj(concat_repr)  # [bsize, ?, DR]
     else:
         ret = concat_repr
     return ret
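The soft-head branch scores every token inside a span, pushes masked slots to a very negative value, and pools token vectors with the resulting softmax. A condensed plain-torch sketch of that weighted pooling, leaving out the top-k shortcut and the expand_ranged_idxes helper (all names below are illustrative):

 import torch

 def softhead_pool(input_expr, span_idxes, span_mask, scorer):
     # input_expr: [bsize, slen, D]; span_idxes/span_mask: [bsize, nspan, MW]
     bsize = input_expr.shape[0]
     arange3_t = torch.arange(bsize).unsqueeze(-1).unsqueeze(-1)         # [bsize, 1, 1]
     scores = scorer(input_expr).squeeze(-1)                             # [bsize, slen]
     span_scores = scores[arange3_t.squeeze(-1),
                          span_idxes.view(bsize, -1)].view(span_idxes.shape)  # [bsize, nspan, MW]
     span_scores = span_scores + (1. - span_mask) * -1e30                # mask padded slots
     prob = span_scores.softmax(-1)                                      # [bsize, nspan, MW]
     tok_repr = input_expr[arange3_t, span_idxes]                        # [bsize, nspan, MW, D]
     return (tok_repr * prob.unsqueeze(-1)).sum(-2)                      # [bsize, nspan, D]

 emb = torch.randn(2, 6, 4)
 idx = torch.tensor([[[1, 2, 2]], [[3, 4, 5]]])       # one span per sentence, max width 3
 msk = torch.tensor([[[1., 1., 0.]], [[1., 1., 1.]]])
 print(softhead_pool(emb, idx, msk, torch.nn.Linear(4, 1)).shape)  # torch.Size([2, 1, 4])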
Example #9
 def _split_extend(self, split_decisions: BK.Expr, cand_mask: BK.Expr):
     # first augment/pad split_decisions
     slice_ones = BK.constants([BK.get_shape(split_decisions, 0), 1],
                               1.)  # [*, 1]
     padded_split_decisions = BK.concat([slice_ones, split_decisions],
                                        -1)  # [*, clen]
     seg_cidxes, seg_masks = BK.mask2idx(
         padded_split_decisions)  # [*, seglen]
     # --
     cand_lens = cand_mask.sum(-1, keepdim=True).long()  # [*, 1]
     seg_masks *= (cand_lens > 0).float()  # for the case of no cands
     # --
     seg_cidxes_special = seg_cidxes + (1. - seg_masks).long(
     ) * cand_lens  # [*, seglen], fill in for paddings
     seg_cidxes_special2 = BK.concat([seg_cidxes_special, cand_lens],
                                     -1)  # [*, seglen+1]
     seg_clens = seg_cidxes_special2[:,
                                     1:] - seg_cidxes_special  # [*, seglen]
     # extend the idxes
     seg_ext_cidxes, seg_ext_masks = expand_ranged_idxes(
         seg_cidxes, seg_clens)  # [*, seglen, MW]
     seg_ext_masks *= seg_masks.unsqueeze(-1)
     return seg_ext_cidxes, seg_ext_masks, seg_masks  # 2x[*, seglen, MW], [*, seglen]
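BK.mask2idx packs the positions where a 0/1 mask fires into a dense index tensor plus a validity mask, judging from how it is used here and in Example #16. A possible plain-torch equivalent, offered only as a guess at its semantics:

 import torch

 def mask2idx(mask: torch.Tensor):
     # mask: [bsize, slen] of 0./1. -> (idxes [bsize, max_count], valid [bsize, max_count])
     counts = mask.sum(-1).long()                                   # [bsize]
     max_count = max(int(counts.max().item()), 1)
     # a stable sort puts the 1-positions first while keeping their original order
     order = torch.sort(-mask, dim=-1, stable=True).indices[:, :max_count]
     valid = (torch.arange(max_count).unsqueeze(0) < counts.unsqueeze(-1)).float()
     idxes = order * valid.long()                                   # zero out padded slots
     return idxes, valid

 m = torch.tensor([[1., 0., 1., 0.], [0., 0., 0., 1.]])
 print(mask2idx(m))  # idxes [[0, 2], [3, 0]], valid [[1., 1.], [1., 0.]]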
Example #10
 def assign_boundaries(self, items: List, boundary_node,
                       flat_mask_t: BK.Expr, flat_hid_t: BK.Expr,
                       indicators: List):
     flat_indicators = boundary_node.prepare_indicators(
         indicators, BK.get_shape(flat_mask_t))
     # --
     _bsize, _dlen = BK.get_shape(flat_mask_t)  # [???, dlen]
     _once_bsize = max(1, int(self.conf.boundary_bsize / max(1, _dlen)))
     # --
     if _once_bsize >= _bsize:
         _, _left_idxes, _right_idxes = boundary_node.decode(
             flat_hid_t, flat_mask_t, flat_indicators)  # [???]
     else:
         _all_left_idxes, _all_right_idxes = [], []
         for ii in range(0, _bsize, _once_bsize):
             _, _one_left_idxes, _one_right_idxes = boundary_node.decode(
                 flat_hid_t[ii:ii + _once_bsize],
                 flat_mask_t[ii:ii + _once_bsize],
                 [z[ii:ii + _once_bsize] for z in flat_indicators])
             _all_left_idxes.append(_one_left_idxes)
             _all_right_idxes.append(_one_right_idxes)
         _left_idxes, _right_idxes = BK.concat(_all_left_idxes,
                                               0), BK.concat(
                                                   _all_right_idxes, 0)
     _arr_left, _arr_right = BK.get_value(_left_idxes), BK.get_value(
         _right_idxes)
     for ii, item in enumerate(items):
         _mention = item.mention
         _start = item._tmp_sstart  # need to minus this!!
         _left_widx, _right_widx = _arr_left[ii].item(
         ) - _start, _arr_right[ii].item() - _start
         # todo(+N): sometimes we can have repeated ones, currently simply over-write!
         if _mention.get_span()[1] == 1:
             _mention.set_span(*(_mention.get_span()),
                               shead=True)  # first move to shead!
         _mention.set_span(_left_widx, _right_widx - _left_widx + 1)
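assign_boundaries bounds memory by decoding at most _once_bsize rows at a time and concatenating the per-chunk outputs along dim 0. The chunking idiom in isolation, with a stand-in function instead of boundary_node.decode:

 import torch

 def chunked_apply(fn, t: torch.Tensor, once_bsize: int) -> torch.Tensor:
     # run fn on slices of at most once_bsize rows, then stitch the results back together
     if once_bsize >= t.shape[0]:
         return fn(t)
     pieces = [fn(t[ii:ii + once_bsize]) for ii in range(0, t.shape[0], once_bsize)]
     return torch.cat(pieces, dim=0)

 hid = torch.randn(10, 7, 16)
 out = chunked_apply(lambda x: x.mean(-1), hid, once_bsize=4)   # stand-in for decode
 print(out.shape)  # torch.Size([10, 7])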
Example #11
 def forward(self, med: ZMediator, **kwargs):
     conf: IdecConnectorAttConf = self.conf
     # --
     # get stack att: already transposed by zmed
     scores_t = med.get_stack_att()  # [*, len_q, len_k, NL, H]
     _d_bs, _dq, _dk, _d_nl, _d_nh = BK.get_shape(scores_t)
     in1_t = scores_t[:, :, :, self.lstart:, :self.head_end].reshape(
         [_d_bs, _dq, _dk, self.d_in])  # [*, lenq, lenk, din]
     in2_t = in1_t.transpose(-3, -2)  # [*, lenk, lenq, din]
     cat_t = self._go_detach(BK.concat([in1_t, in2_t],
                                       -1))  # [*, lenk, lenq, din*2]
     # further affine
     cat_drop_t = self.pre_mid_drop(cat_t)  # [*, lenk, lenq, din*2]
     ret_t = self.mid_aff(cat_drop_t)  # [*, lenk, lenq, M]
     return ret_t
Example #12
 def s0_open_new_steps(self, bsize: int, ssize: int, mask: BK.Expr = None):
     assert ssize > 0
     assert self._cur_layer_idx == -1
     self._cur_layer_idx = 0
     # --
     new_mask = BK.constants([bsize, ssize], 1.) if mask is None else mask  # [*, ssize]
     # --
     # prepare for store_lstate selecting
     if len(self.cum_state_lset) > 0:  # does any layer need to accumulate?
         self._arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [bsize, 1]
         # note: if no last state, simply clamp 0, otherwise, offset by 1 since we will concat later
         self._arange_sel_t = mask2posi_padded(new_mask, 0, 0) if mask is None else mask2posi_padded(new_mask, 1, 0)
     # prev_steps = self.steps  # previous accumulated steps
     self.steps += ssize
     self.mask = new_mask if self.mask is None else BK.concat([self.mask, new_mask], 1)  # [*, old+new]
     self.positions = mask2posi(self.mask, offset=-1, cmin=0)  # [*, old+new], recalculate!!
Example #13
 def forward(self, inputs, add_bos=False, add_eos=False):
     conf: CharCnnInputEmbedderConf = self.conf
     # --
     voc = self.voc
     char_input_t = BK.input_idx(inputs)  # [*, len, clen]
     # todo(note): no need for replacing to unk for char!!
     # bos and eos
     all_input_slices = []
     slice_shape = BK.get_shape(char_input_t)
     slice_shape[-2] = 1  # [*, 1, clen]
     if add_bos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.bos, dtype=char_input_t.dtype))
     all_input_slices.append(char_input_t)  # [*, len, clen]
     if add_eos:
         all_input_slices.append(
             BK.constants(slice_shape, voc.eos, dtype=char_input_t.dtype))
     final_input_t = BK.concat(all_input_slices, -2)  # [*, 1?+len+1?, clen]
     # char embeddings
     char_embed_expr = self.E(final_input_t)  # [*, ??, dim]
     # char cnn
     ret = self.cnn(char_embed_expr)
     return ret
Example #14
 def forward(self, input_map: Dict):
     add_bos, add_eos = self.conf.add_bos, self.conf.add_eos
     ret = OrderedDict()  # [*, len, ?]
     for key, embedder_pack in self.embedders.items(
     ):  # according to REG order!!
         embedder, input_name = embedder_pack
         one_expr = embedder(input_map[input_name],
                             add_bos=add_bos,
                             add_eos=add_eos)
         ret[key] = one_expr
     # mask expr
     mask_expr = input_map.get("mask")
     if mask_expr is not None:
         all_input_slices = []
         mask_slice = BK.constants(BK.get_shape(mask_expr)[:-1] + [1],
                                   1,
                                   dtype=mask_expr.dtype)  # [*, 1]
         if add_bos:
             all_input_slices.append(mask_slice)
         all_input_slices.append(mask_expr)
         if add_eos:
             all_input_slices.append(mask_slice)
         mask_expr = BK.concat(all_input_slices, -1)  # [*, ?+len+?]
     return mask_expr, ret
Example #15
 def beam_search(self, batch_size: int, beam_k: int, ret_best: bool = True):
     _NEG_INF = Constants.REAL_PRAC_MIN
     # --
     cur_step = 0
     cache: DecCache = None
     # init: keep the seq of scores rather than traceback!
     start_vals_shape = [batch_size, 1]  # [bs, 1]
     all_preds_t = BK.constants_idx(start_vals_shape, 0).unsqueeze(
         -1)  # [bs, K, step], todo(note): start with 0!
     all_scores_t = BK.zeros(start_vals_shape).unsqueeze(
         -1)  # [bs, K, step]
     accu_scores_t = BK.zeros(start_vals_shape)  # [bs, K]
     arange_t = BK.arange_idx(batch_size).unsqueeze(-1)  # [bs, 1]
     # while loop
     prev_k = 1  # start with single one
     while not self.is_end(cur_step):
         # expand and score
         cache, scores_t, masks_t = self.step_score(
             cur_step, prev_k, cache)  # ..., [bs*pK, L], [bs*pK]
         scores_t_shape = BK.get_shape(scores_t)
         last_dim = scores_t_shape[-1]  # L
         # modify score to handle mask: keep previous pred for the masked items!
         sel_scores_t = BK.constants([batch_size, prev_k, last_dim],
                                     1.)  # [bs, pk, L]
         sel_scores_t.scatter_(-1, all_preds_t[:, :, -1:],
                               -1)  # [bs, pk, L]
         sel_scores_t = scores_t + _NEG_INF * (
             sel_scores_t.view(scores_t_shape) *
             (1. - masks_t).unsqueeze(-1))  # [bs*pK, L]
         # first select topk locally, note: here no need to sort!
         local_k = min(last_dim, beam_k)
         l_topk_scores, l_topk_idxes = sel_scores_t.topk(
             local_k, -1, sorted=False)  # [bs*pK, lK]
         # then topk globally on full pK*K
         add_score_shape = [batch_size, prev_k, local_k]
         to_sel_shape = [batch_size, prev_k * local_k]
         global_k = min(to_sel_shape[-1], beam_k)  # new k
         to_sel_scores, to_sel_idxes = \
             (l_topk_scores.view(add_score_shape) + accu_scores_t.unsqueeze(-1)).view(to_sel_shape), \
             l_topk_idxes.view(to_sel_shape)  # [bs, pK*lK]
         _, g_topk_idxes = to_sel_scores.topk(global_k, -1,
                                              sorted=True)  # [bs, gK]
         # get to know the idxes
         new_preds_t = to_sel_idxes.gather(-1, g_topk_idxes)  # [bs, gK]
         new_pk_idxes = (
             g_topk_idxes // local_k
         )  # which previous idx (in beam) are selected? [bs, gK]
         # get current pred and scores (handling mask)
         scores_t3 = scores_t.view([batch_size, -1,
                                    last_dim])  # [bs, pK, L]
         masks_t2 = masks_t.view([batch_size, -1])  # [bs, pK]
         new_masks_t = masks_t2[arange_t, new_pk_idxes]  # [bs, gK]
         # -- one-step score for new selections: [bs, gK], note: zero scores for masked ones
         new_scores_t = scores_t3[arange_t, new_pk_idxes,
                                  new_preds_t] * new_masks_t  # [bs, gK]
         # ending
         new_arrange_idxes = (arange_t * prev_k + new_pk_idxes).view(
             -1)  # [bs*gK]
         cache.arrange_idxes(new_arrange_idxes)
         self.step_end(cur_step, global_k, cache,
                       new_preds_t.view(-1))  # modify in cache
         # prepare next & judge ending
         all_preds_t = BK.concat([
             all_preds_t[arange_t, new_pk_idxes],
             new_preds_t.unsqueeze(-1)
         ], -1)  # [bs, gK, step]
         all_scores_t = BK.concat([
             all_scores_t[arange_t, new_pk_idxes],
             new_scores_t.unsqueeze(-1)
         ], -1)  # [bs, gK, step]
         accu_scores_t = accu_scores_t[
             arange_t, new_pk_idxes] + new_scores_t  # [bs, gK]
         prev_k = global_k  # for next step
         cur_step += 1
     # --
     # sort and ret at a final step
     _, final_idxes = accu_scores_t.topk(prev_k, -1, sorted=True)  # [bs, K]
     ret_preds = all_preds_t[
         arange_t, final_idxes][:, :,
                                1:]  # [bs, K, steps], exclude dummy start!
     ret_scores = all_scores_t[arange_t, final_idxes][:, :,
                                                      1:]  # [bs, K, steps]
     if ret_best:
         return ret_preds[:, 0], ret_scores[:, 0]  # [bs, slen]
     else:
         return ret_preds, ret_scores  # [bs, topk, slen]
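The core of each beam-search iteration is a two-stage selection: a local top-k inside every beam item, then a global top-k over the flattened prev_k*local_k candidates, from which both the new predictions and the surviving beam indices are recovered. A compact plain-torch sketch of just that step (scores are assumed to be already masked):

 import torch

 def beam_expand_step(scores_t, accu_scores_t, beam_k):
     # scores_t: [bs, pK, L] one-step scores; accu_scores_t: [bs, pK] accumulated scores
     bs, prev_k, L = scores_t.shape
     local_k = min(L, beam_k)
     l_scores, l_idxes = scores_t.topk(local_k, -1, sorted=False)        # [bs, pK, lK]
     flat_scores = (l_scores + accu_scores_t.unsqueeze(-1)).view(bs, prev_k * local_k)
     flat_idxes = l_idxes.view(bs, prev_k * local_k)                     # [bs, pK*lK]
     global_k = min(prev_k * local_k, beam_k)
     new_accu, g_idxes = flat_scores.topk(global_k, -1, sorted=True)     # [bs, gK]
     new_preds = flat_idxes.gather(-1, g_idxes)       # which label each survivor predicts
     new_pk_idxes = g_idxes // local_k                # which previous beam item it extends
     return new_preds, new_pk_idxes, new_accu

 s, a = torch.randn(2, 3, 5), torch.zeros(2, 3)       # bs=2, prev_k=3, L=5
 print([x.shape for x in beam_expand_step(s, a, beam_k=4)])  # 3 x torch.Size([2, 4])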
Example #16
 def _loss_feed_cand(self, mask_expr, cand_full_scores, pred_cand_decisions,
                     expr_seq_gaddr, expr_group_widxes, expr_group_masks,
                     expr_loss_weight_non):
     conf: SoftExtractorConf = self.conf
     bsize, slen = BK.get_shape(mask_expr)
     arange3_t = BK.arange_idx(bsize).unsqueeze(-1).unsqueeze(
         -1)  # [*, 1, 1]
     # --
     # step 1.1: bag loss
     cand_gold_mask = (expr_seq_gaddr >=
                       0).float() * mask_expr  # [*, slen], whether is-arg
     raw_loss_cand = BK.loss_binary(
         cand_full_scores,
         cand_gold_mask,
         label_smoothing=conf.cand_label_smoothing)  # [*, slen]
     # how to weight?
     extended_scores_t = cand_full_scores[arange3_t, expr_group_widxes] + (
         1. - expr_group_masks) * Constants.REAL_PRAC_MIN  # [*, slen, MW]
     if BK.is_zero_shape(extended_scores_t):
         extended_scores_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_scores_max_t, _ = extended_scores_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = ((cand_full_scores - extended_scores_max_t) *
                _w_alpha).exp()  # [*, slen]
     if not conf.cand_loss_div_max:  # div sum-all, like doing softmax
         _weight = _weight / (
             (extended_scores_t - extended_scores_max_t.unsqueeze(-1)) *
             _w_alpha).exp().sum(-1)
     _weight = _weight * (_weight >=
                          conf.cand_loss_weight_thresh).float()  # [*, slen]
     if conf.cand_detach_weight:
         _weight = _weight.detach()
     # pos poison (dis-encouragement)
     if conf.cand_loss_pos_poison:
         poison_loss = BK.loss_binary(
             cand_full_scores,
             1. - cand_gold_mask,
             label_smoothing=conf.cand_label_smoothing)  # [*, slen]
         raw_loss_cand = raw_loss_cand * _weight + poison_loss * cand_gold_mask * (
             1. - _weight)  # [*, slen]
     else:
         raw_loss_cand = raw_loss_cand * _weight
     # final weight it
     cand_loss_weights = BK.where(cand_gold_mask == 0.,
                                  expr_loss_weight_non.unsqueeze(-1) *
                                  conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_cand_item = LossHelper.compile_leaf_loss(
         f"cand", (raw_loss_cand * final_cand_loss_weights).sum(),
         final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_cand)
     # step 1.2: feed cand
     # todo(+N): currently only pred/sample, whether adding certain teacher-forcing?
     sample_decisions = (BK.sigmoid(cand_full_scores) >= BK.rand(
         cand_full_scores.shape)).float() * mask_expr  # [*, slen]
     _use_sample_mask = (BK.rand([bsize])
                         <= conf.cand_feed_sample_rate).float().unsqueeze(
                             -1)  # [*, 1], seq-level
     feed_cand_decisions = (_use_sample_mask * sample_decisions +
                            (1. - _use_sample_mask) * pred_cand_decisions
                            )  # [*, slen]
     # next
     cand_widxes, cand_masks = BK.mask2idx(feed_cand_decisions)  # [*, clen]
     # --
     # extra: loss_cand_entropy
     rets = [loss_cand_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_scores_t.softmax(-1)  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([
             expr_seq_gaddr[:, :1] >= 0,
             expr_seq_gaddr[:, 1:] != expr_seq_gaddr[:, :-1]
         ], -1).float() * cand_gold_mask
         _loss_ent_item = LossHelper.compile_leaf_loss(
             f"cand_ent", (_ent * _ent_mask).sum(),
             _ent_mask.sum(),
             loss_lambda=_loss_cand_entropy)
         rets.append(_loss_ent_item)
     # --
     return rets, cand_widxes, cand_masks
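The entropy mask above keeps only the first token of each gold bag by comparing every gold address with its left neighbor (and counting the very first position if it is gold). The same pattern in isolation, in plain torch:

 import torch

 def first_in_group_mask(gaddr: torch.Tensor) -> torch.Tensor:
     # gaddr: [bs, slen] group address per token, negative for tokens outside any group
     is_first = torch.cat([gaddr[:, :1] >= 0,                    # first position counts if gold
                           gaddr[:, 1:] != gaddr[:, :-1]], -1)   # later ones: address changed
     return is_first.float() * (gaddr >= 0).float()              # drop non-group tokens

 g = torch.tensor([[-1, 0, 0, 1, 1, -1]])
 print(first_in_group_mask(g))  # tensor([[0., 1., 0., 1., 0., 0.]])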
Example #17
 def loss(self,
          insts: Union[List[Sent], List[Frame]],
          input_expr: BK.Expr,
          mask_expr: BK.Expr,
          pair_expr: BK.Expr = None,
          lookup_flatten=False,
          external_extra_score: BK.Expr = None):
     conf: AnchorExtractorConf = self.conf
     assert not lookup_flatten
     bsize, slen = BK.get_shape(mask_expr)
     # --
     # step 0: prepare
     arr_items, expr_seq_gaddr, expr_seq_labs, expr_group_widxes, expr_group_masks, expr_loss_weight_non = \
         self.helper.prepare(insts, mlen=BK.get_shape(mask_expr, -1), use_cache=True)
     arange2_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
     arange3_t = arange2_t.unsqueeze(-1)  # [*, 1, 1]
     # --
     # step 1: label, simply scoring everything!
     _main_t, _pair_t = self.lab_node.transform_expr(input_expr, pair_expr)
     all_scores_t = self.lab_node.score_all(
         _main_t,
         _pair_t,
         mask_expr,
         None,
         local_normalize=False,
         extra_score=external_extra_score
     )  # unnormalized scores [*, slen, L]
     all_probs_t = all_scores_t.softmax(-1)  # [*, slen, L]
     all_gprob_t = all_probs_t.gather(-1,
                                      expr_seq_labs.unsqueeze(-1)).squeeze(
                                          -1)  # [*, slen]
     # how to weight
     extended_gprob_t = all_gprob_t[
         arange3_t, expr_group_widxes] * expr_group_masks  # [*, slen, MW]
     if BK.is_zero_shape(extended_gprob_t):
         extended_gprob_max_t = BK.zeros(mask_expr.shape)  # [*, slen]
     else:
         extended_gprob_max_t, _ = extended_gprob_t.max(-1)  # [*, slen]
     _w_alpha = conf.cand_loss_weight_alpha
     _weight = (
         (all_gprob_t * mask_expr) /
         (extended_gprob_max_t.clamp(min=1e-5)))**_w_alpha  # [*, slen]
     _label_smoothing = conf.lab_conf.labeler_conf.label_smoothing
     _loss1 = BK.loss_nll(all_scores_t,
                          expr_seq_labs,
                          label_smoothing=_label_smoothing)  # [*, slen]
     _loss2 = BK.loss_nll(all_scores_t,
                          BK.constants_idx([bsize, slen], 0),
                          label_smoothing=_label_smoothing)  # [*, slen]
     _weight1 = _weight.detach() if conf.detach_weight_lab else _weight
     _raw_loss = _weight1 * _loss1 + (1. - _weight1) * _loss2  # [*, slen]
     # final weight it
     cand_loss_weights = BK.where(expr_seq_labs == 0,
                                  expr_loss_weight_non.unsqueeze(-1) *
                                  conf.loss_weight_non,
                                  mask_expr)  # [*, slen]
     final_cand_loss_weights = cand_loss_weights * mask_expr  # [*, slen]
     loss_lab_item = LossHelper.compile_leaf_loss(
         f"lab", (_raw_loss * final_cand_loss_weights).sum(),
         final_cand_loss_weights.sum(),
         loss_lambda=conf.loss_lab,
         gold=(expr_seq_labs > 0).float().sum())
     # --
     # step 1.5
     all_losses = [loss_lab_item]
     _loss_cand_entropy = conf.loss_cand_entropy
     if _loss_cand_entropy > 0.:
         _prob = extended_gprob_t  # [*, slen, MW]
         _ent = EntropyHelper.self_entropy(_prob)  # [*, slen]
         # [*, slen], only first one in bag
         _ent_mask = BK.concat([expr_seq_gaddr[:,:1]>=0, expr_seq_gaddr[:,1:]!=expr_seq_gaddr[:,:-1]],-1).float() \
                     * (expr_seq_labs>0).float()
         _loss_ent_item = LossHelper.compile_leaf_loss(
             f"cand_ent", (_ent * _ent_mask).sum(),
             _ent_mask.sum(),
             loss_lambda=_loss_cand_entropy)
         all_losses.append(_loss_ent_item)
     # --
     # step 4: extend (select topk)
     if conf.loss_ext > 0.:
         if BK.is_zero_shape(extended_gprob_t):
             flt_mask = (BK.zeros(mask_expr.shape) > 0)
         else:
             _topk = min(conf.ext_loss_topk,
                         BK.get_shape(extended_gprob_t,
                                      -1))  # number to extract
             _topk_grpob_t, _ = extended_gprob_t.topk(
                 _topk, dim=-1)  # [*, slen, K]
             flt_mask = (expr_seq_labs >
                         0) & (all_gprob_t >= _topk_grpob_t.min(-1)[0]) & (
                             _weight > conf.ext_loss_thresh)  # [*, slen]
         flt_sidx = BK.arange_idx(bsize).unsqueeze(-1).expand_as(flt_mask)[
             flt_mask]  # [?]
         flt_expr = input_expr[flt_mask]  # [?, D]
         flt_full_expr = self._prepare_full_expr(flt_mask)  # [?, slen, D]
         flt_items = arr_items.flatten()[BK.get_value(
             expr_seq_gaddr[flt_mask])]  # [?]
         flt_weights = _weight.detach(
         )[flt_mask] if conf.detach_weight_ext else _weight[flt_mask]  # [?]
         loss_ext_item = self.ext_node.loss(flt_items,
                                            input_expr[flt_sidx],
                                            flt_expr,
                                            flt_full_expr,
                                            mask_expr[flt_sidx],
                                            flt_extra_weights=flt_weights)
         all_losses.append(loss_ext_item)
     # --
     # return loss
     ret_loss = LossHelper.combine_multiple_losses(all_losses)
     return ret_loss, None
Example #18
 def forward(self, inputs, vstate: VrecSteppingState = None, inc_cls=False):
     conf: BertEncoderConf = self.conf
     # --
     no_bert_ft = (not conf.bert_ft)  # whether to fine-tune bert (if not, detach the hiddens!)
     impl = self.impl
     # --
     # prepare inputs
     if not isinstance(inputs, BerterInputBatch):
         inputs = self.create_input_batch(inputs)
     all_output_layers = []  # including embeddings
     # --
     # get embeddings (for embeddings, we simply forward once!)
     mask_repl_rate = conf.bert_repl_mask_rate if self.is_training() else 0.
     input_ids, input_masks = inputs.get_basic_inputs(
         mask_repl_rate)  # [bsize, 1+sub_len+1]
     other_embeds = None
     if self.other_embed_nodes is not None and len(
             self.other_embed_nodes) > 0:
         other_embeds = 0.
         for other_name, other_node in self.other_embed_nodes.items():
             other_embeds += other_node(
                 inputs.other_factors[other_name]
             )  # should be prepared correspondingly!!
     # --
     # forward layers (for layers, we may need to split!)
     # todo(+N): we simply split things apart, thus middle parts may lack CLS/SEP, and not true global att
     # todo(+N): the lengths currently are hard-coded!!
     MAX_LEN = 512  # max len
     INBUF_LEN = 50  # in-between buffer for splits, for both sides!
     cur_sub_len = BK.get_shape(input_ids, 1)  # 1+sub_len+1
     needs_split = (cur_sub_len > MAX_LEN)
     if needs_split:  # decide split and merge points
         split_points = self._calculate_split_points(
             cur_sub_len, MAX_LEN, INBUF_LEN)
         zwarn(
             f"Multi-seg for Berter: {cur_sub_len}//{len(split_points)}->{split_points}"
         )
     # --
     # todo(note): we also need split from embeddings
     if needs_split:
         all_embed_pieces = []
         split_extended_attention_mask = []
         for o_s, o_e, i_s, i_e in split_points:
             piece_embeddings, piece_extended_attention_mask = impl.forward_embedding(
                 *[(None if z is None else z[:, o_s:o_e]) for z in [
                     input_ids, input_masks, inputs.batched_token_type_ids,
                     inputs.batched_position_ids, other_embeds
                 ]])
             all_embed_pieces.append(piece_embeddings[:, i_s:i_e])
             split_extended_attention_mask.append(
                 piece_extended_attention_mask)
         embeddings = BK.concat(all_embed_pieces, 1)  # concat back to full
         extended_attention_mask = None
     else:
         embeddings, extended_attention_mask = impl.forward_embedding(
             input_ids, input_masks, inputs.batched_token_type_ids,
             inputs.batched_position_ids, other_embeds)
         split_extended_attention_mask = None
     if no_bert_ft:  # stop gradient
         embeddings = embeddings.detach()
     # --
     cur_hidden = embeddings
     all_output_layers.append(embeddings)  # *[bsize, 1+sub_len+1, D]
     # also prepare mapper idxes for sub <-> orig
     # todo(+N): currently only use the first sub-word!
     idxes_arange2 = inputs.arange2_t  # [bsize, 1]
     batched_first_idxes_p1 = (1 + inputs.batched_first_idxes) * (
         inputs.batched_first_mask.long())  # plus one for CLS offset!
     if inc_cls:  # [bsize, 1+orig_len]
         idxes_sub2orig = BK.concat([
             BK.constants_idx([inputs.bsize, 1], 0), batched_first_idxes_p1
         ], 1)
     else:  # [bsize, orig_len]
         idxes_sub2orig = batched_first_idxes_p1
     _input_masks0 = None  # used for vstate back, make it 0. for BOS and EOS
     # for ii in range(impl.num_hidden_layers):
     for ii in range(max(self.actual_output_layers)):  # no need to compute more layers than actually required
         # forward multiple times with splitting if needed
         if needs_split:
             all_pieces = []
             for piece_idx, piece_points in enumerate(split_points):
                 o_s, o_e, i_s, i_e = piece_points
                 piece_res = impl.forward_hidden(
                     ii, cur_hidden[:, o_s:o_e],
                     split_extended_attention_mask[piece_idx])[:, i_s:i_e]
                 all_pieces.append(piece_res)
             new_hidden = BK.concat(all_pieces, 1)  # concat back to full
         else:
             new_hidden = impl.forward_hidden(ii, cur_hidden,
                                              extended_attention_mask)
         if no_bert_ft:  # stop gradient
             new_hidden = new_hidden.detach()
         if vstate is not None:
             # from 1+sub_len+1 -> (inc_cls?)+orig_len
             new_hidden2orig = new_hidden[
                 idxes_arange2, idxes_sub2orig]  # [bsize, 1?+orig_len, D]
             # update
             new_hidden2orig_ret = vstate.update(
                 new_hidden2orig)  # [bsize, 1?+orig_len, D]
             if new_hidden2orig_ret is not None:
                 # calculate when needed
                 if _input_masks0 is None:  # [bsize, 1+sub_len+1, 1] with 1. only for real valid ones
                     _input_masks0 = inputs._aug_ends(
                         inputs.batched_input_mask, 0., 0., 0.,
                         BK.float32).unsqueeze(-1)
                 # back to 1+sub_len+1; todo(+N): here we simply add and //2, and no CLS back from orig to sub!!
                 tmp_orig2sub = new_hidden2orig_ret[
                     idxes_arange2,
                     int(inc_cls) +
                     inputs.batched_rev_idxes]  # [bsize, sub_len, D]
                 tmp_slice_size = BK.get_shape(tmp_orig2sub)
                 tmp_slice_size[1] = 1
                 tmp_slice_zero = BK.zeros(tmp_slice_size)
                 tmp_orig2sub_aug = BK.concat(
                     [tmp_slice_zero, tmp_orig2sub, tmp_slice_zero],
                     1)  # [bsize, 1+sub_len+1, D]
                 new_hidden = new_hidden * (1. - _input_masks0) + (
                     (new_hidden + tmp_orig2sub_aug) / 2.) * _input_masks0
         all_output_layers.append(new_hidden)
         cur_hidden = new_hidden
     # finally, prepare return
     final_output_layers = [
         all_output_layers[z] for z in conf.bert_output_layers
     ]  # *[bsize,1+sl+1,D]
     combined_output = self.combiner(
         final_output_layers)  # [bsize, 1+sl+1, ??]
     final_ret = combined_output[idxes_arange2,
                                 idxes_sub2orig]  # [bsize, 1?+orig_len, D]
     return final_ret
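For over-long inputs, the encoder feeds overlapping windows (MAX_LEN wide, with an INBUF_LEN context buffer on both sides) and keeps only each window's inner part, so the concatenated inner parts cover the full sequence. One possible way to compute such (outer_start, outer_end, inner_start, inner_end) tuples, not necessarily what _calculate_split_points actually does:

 def make_split_points(total_len: int, max_len: int, buf: int):
     # each window is fed with `buf` extra context on both sides, but only its inner part
     # is kept, so the concatenated inner parts exactly cover [0, total_len)
     points = []
     keep = max_len - 2 * buf                     # positions each window contributes
     inner_start = 0
     while inner_start < total_len:
         inner_end = min(inner_start + keep, total_len)
         outer_start = max(0, inner_start - buf)
         outer_end = min(total_len, inner_end + buf)
         points.append((outer_start, outer_end,
                        inner_start - outer_start, inner_end - outer_start))
         inner_start = inner_end
     return points

 for o_s, o_e, i_s, i_e in make_split_points(1000, 512, 50):
     print(o_s, o_e, i_s, i_e)   # (0, 462, 0, 412), (362, 874, 50, 462), (774, 1000, 50, 226)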
Example #19
 def score_all(self,
               expr_main: BK.Expr,
               expr_pair: BK.Expr,
               input_mask: BK.Expr,
               gold_idxes: BK.Expr,
               local_normalize: bool = None,
               use_bigram: bool = True,
               extra_score: BK.Expr = None):
     conf: SeqLabelerConf = self.conf
     # first collect basic scores
     if conf.use_seqdec:
         # first prepare init hidden
         sd_init_t = self.prepare_sd_init(expr_main, expr_pair)  # [*, hid]
         # init cache: no mask at batch level
         sd_cache = self.seqdec.go_init(
             sd_init_t, init_mask=None)  # and no need to cum_state here!
         # prepare inputs at once
         if conf.sd_skip_non:
             gold_valid_mask = (gold_idxes > 0).float(
             ) * input_mask  # [*, slen], todo(note): fix 0 as non here!
             gv_idxes, gv_masks = BK.mask2idx(gold_valid_mask)  # [*, ?]
             bsize = BK.get_shape(gold_idxes, 0)
             arange_t = BK.arange_idx(bsize).unsqueeze(-1)  # [*, 1]
             # select and forward
             gv_embeds = self.laber.lookup(
                 gold_idxes[arange_t, gv_idxes])  # [*, ?, E]
             gv_input_t = self.sd_input_aff(
                 [expr_main[arange_t, gv_idxes], gv_embeds])  # [*, ?, hid]
             gv_hid_t = self.seqdec.go_feed(sd_cache, gv_input_t,
                                            gv_masks)  # [*, ?, hid]
             # select back and output_aff
             aug_hid_t = BK.concat([sd_init_t.unsqueeze(-2), gv_hid_t],
                                   -2)  # [*, 1+?, hid]
             sel_t = BK.pad(gold_valid_mask[:, :-1].cumsum(-1), (1, 0),
                            value=0.).long()  # [*, 1+(slen-1)]
             shifted_hid_t = aug_hid_t[arange_t, sel_t]  # [*, slen, hid]
         else:
             gold_idx_embeds = self.laber.lookup(gold_idxes)  # [*, slen, E]
             all_input_t = self.sd_input_aff(
                 [expr_main,
                  gold_idx_embeds])  # inputs to dec, [*, slen, hid]
             all_hid_t = self.seqdec.go_feed(
                 sd_cache, all_input_t,
                 input_mask)  # output-hids, [*, slen, hid]
             shifted_hid_t = BK.concat(
                 [sd_init_t.unsqueeze(-2), all_hid_t[:, :-1]],
                 -2)  # [*, slen, hid]
         # scorer
         pre_labeler_t = self.sd_output_aff([expr_main, shifted_hid_t
                                             ])  # [*, slen, hid]
     else:
         pre_labeler_t = expr_main  # [*, slen, Dm']
     # score with labeler (no norm here since we may need to add other scores)
     scores_t = self.laber.score(
         pre_labeler_t,
         None if expr_pair is None else expr_pair.unsqueeze(-2),
         input_mask,
         extra_score=extra_score,
         local_normalize=False)  # [*, slen, L]
     # bigram score addition
     if conf.use_bigram and use_bigram:
         bigram_scores_t = self.bigram.get_matrix()[
             gold_idxes[:, :-1]]  # [*, slen-1, L]
         score_shape = BK.get_shape(bigram_scores_t)
         score_shape[1] = 1
         slice_t = BK.constants(
             score_shape,
             0.)  # fix 0., no transition from BOS (and EOS) for simplicity!
         bigram_scores_t = BK.concat([slice_t, bigram_scores_t],
                                     1)  # [*, slen, L]
         scores_t += bigram_scores_t  # [*, slen]
     # local normalization?
     scores_t = self.laber.output_score(scores_t, local_normalize)
     return scores_t
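The bigram addition looks up one transition row per previous gold label and shifts the rows one step to the right, inserting a zero row for the first position so nothing transitions into it. A minimal stand-alone sketch in plain torch, with a hypothetical transition matrix:

 import torch

 def add_bigram_scores(scores_t, gold_idxes, trans_matrix):
     # scores_t: [bs, slen, L] unary scores; gold_idxes: [bs, slen]; trans_matrix: [L, L]
     bigram = trans_matrix[gold_idxes[:, :-1]]                  # [bs, slen-1, L]
     zero_row = torch.zeros_like(scores_t[:, :1])               # [bs, 1, L]: no transition into position 0
     return scores_t + torch.cat([zero_row, bigram], dim=1)     # [bs, slen, L]

 L = 4
 out = add_bigram_scores(torch.randn(2, 5, L), torch.randint(0, L, (2, 5)), torch.randn(L, L))
 print(out.shape)  # torch.Size([2, 5, 4])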