def loss_cf(self, cf_scores: List[BK.Expr], insts, loss_cf: float):
    """Build one masked L2 loss item per entry of ``cf_scores``.

    For every index ``i`` in ``cf_scores`` the i-th oracle slice of each
    instance (as produced by ``self.cf_oracle_f``) is batched, scaled by
    ``conf.cf_scale`` and compared against ``cf_scores[i].squeeze(-1)``
    with a squared error.

    Args:
        cf_scores: One score tensor per index; its last dimension is
            squeezed away before the comparison.
        insts: Batch of instances; each is passed to ``self.cf_oracle_f``
            and must expose a ``sent`` with a length.
        loss_cf: Forwarded as ``loss_lambda`` to
            ``LossHelper.compile_leaf_loss``.

    Returns:
        A list holding the result of ``LossHelper.compile_leaf_loss`` for
        each entry of ``cf_scores``, named ``cf0``, ``cf1``, ...
    """
    conf: SeqExitHelperConf = self.conf
    assert self.is_cf
    # Oracle targets, one entry per instance: bs*[NL, slen] or bs*[NL]
    # (shapes as annotated by the original author).
    gold_per_inst = [self.cf_oracle_f(one_inst) for one_inst in insts]
    sent_lens = [len(one_inst.sent) for one_inst in insts]
    valid_t = BK.input_real(DataPadder.lengths2mask(sent_lens))  # [bs, slen]
    loss_items = []
    for layer_idx, layer_scores in enumerate(cf_scores):
        layer_gold = [gold[layer_idx] for gold in gold_per_inst]
        if not conf.cf_use_seq:
            # Token-level targets, padded with 1. -> [bs, slen]
            target_t = BK.input_real(DataPadder.go_batch_2d(layer_gold, 1.))
            # Randomly drop positions; the drop threshold grows with the
            # (still unscaled) oracle value, so it must be computed before
            # the in-place scaling below.
            discard_t = (target_t ** conf.cf_loss_discard_curve) \
                * conf.cf_loss_discard
            kept_t = BK.rand(target_t.shape) >= discard_t
            weight_t = kept_t * valid_t
            target_t *= conf.cf_scale
        else:
            # Sequence-level targets: one value per instance -> [bs]
            target_t = BK.input_real(layer_gold)
            target_t *= conf.cf_scale
            # every instance counts: an all-ones mask of length bs
            weight_t = BK.zeros([len(target_t)]) + 1
        # simple L2 loss, summed over the kept positions
        sq_err_t = (layer_scores.squeeze(-1) - target_t) ** 2
        loss_items.append(
            LossHelper.compile_leaf_loss(
                f"cf{layer_idx}",
                (sq_err_t * weight_t).sum(),
                weight_t.sum(),
                loss_lambda=loss_cf,
            )
        )
    return loss_items
def __init__(self, ibatch: InputBatch, IDX_PAD: int):
    """Batch the per-item ``seq_info`` fields of ``ibatch`` into tensors.

    Args:
        ibatch: Input batch; ``len(ibatch)`` gives the batch size and
            every element of ``ibatch.items`` must carry a ``seq_info``.
        IDX_PAD: Padding id used for ``enc_input_ids`` (cast to ``int``).
    """

    def _padded_idx(rows, pad_value):
        # pad the ragged rows into a 2d batch and convert to an index tensor
        return BK.input_idx(DataPadder.go_batch_2d(rows, pad_value))

    def _length_mask(rows):
        # real-valued mask derived from the length of each row
        return BK.input_real(DataPadder.lengths2mask([len(r) for r in rows]))

    # preps: batch-index ranges at increasing rank
    self.bsize = len(ibatch)
    batch_range = BK.arange_idx(self.bsize)
    self.arange1_t = batch_range  # [bsize]
    self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
    self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
    # gather the per-item sequence infos
    infos = [item.seq_info for item in ibatch.items]
    # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
    enc_ids = [info.enc_input_ids for info in infos]
    self.enc_input_ids = _padded_idx(enc_ids, int(IDX_PAD))
    self.enc_input_masks = _length_mask(enc_ids)
    self.enc_input_segids = _padded_idx(
        [info.enc_input_segids for info in infos], 0)
    # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes
    dec_idxes = [info.dec_sel_idxes for info in infos]
    self.dec_sel_idxes = _padded_idx(dec_idxes, 0)
    self.dec_sel_lens = _padded_idx([info.dec_sel_lens for info in infos], 1)
    self.dec_sel_masks = _length_mask(dec_idxes)
    dec_len = BK.get_shape(self.dec_sel_masks, 1)
    # Offsets are padded with dec_len itself: positions run 0..dec_len-1,
    # so a padded offset can never be counted as "reached" below.
    offsets_t = _padded_idx([info.dec_offsets for info in infos], dec_len)
    # For each position, count how many offsets are <= that position.
    positions_t = BK.arange_idx(dec_len).unsqueeze(0).unsqueeze(-1)
    reached_t = positions_t >= offsets_t.unsqueeze(-2)
    # note: CLS as -1, then 0,1,2,..., PAD gets -2!
    self.dec_sent_idxes = reached_t.sum(-1).long() - 1
    self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
    # dec -> enc: [*, len_enc]; left as None here, calculated when needed
    # note: require 1-to-1 mapping (except pads)!!
    self._enc_back_hits = None
    self._enc_back_sel_idxes = None