Example #1
0
 def loss_cf(self, cf_scores: List[BK.Expr], insts, loss_cf: float):
     """Compute a per-layer L2 loss between predicted cf scores and oracle values.

     cf_scores: one score tensor per layer; last dim is squeezed before the loss,
         so each is [..., 1] (presumably [bs, slen, 1] or [bs, 1] — TODO confirm).
     insts: batch of instances; each must expose `.sent` (its len gives slen).
     loss_cf: loss lambda attached to each returned loss item.
     Returns: list of leaf-loss items, one per entry of `cf_scores`.
     """
     conf: SeqExitHelperConf = self.conf
     # --
     assert self.is_cf
     # get oracle
     oracles = [self.cf_oracle_f(ff)
                for ff in insts]  # bs*[NL, slen] or bs*[NL]
     rets = []
     # padding mask over original sentence lengths
     mask_t = BK.input_real(
         DataPadder.lengths2mask([len(z.sent)
                                  for z in insts]))  # [bs, slen]
     for one_li, one_scores in enumerate(cf_scores):
         if conf.cf_use_seq:
             # sequence-level oracle: one scalar per instance for this layer
             one_oracle_t = BK.input_real([z[one_li]
                                           for z in oracles])  # [bs]
             one_oracle_t *= conf.cf_scale
             # all-ones mask: every instance contributes
             one_mask_t = BK.zeros([len(one_oracle_t)]) + 1
         else:
             # token-level oracle: pad this layer's rows to [bs, slen] with 1.
             one_oracle_t = BK.input_real(
                 DataPadder.go_batch_2d([z[one_li] for z in oracles],
                                        1.))  # [bs, slen]
             # randomly discard tokens: keep a token iff rand >= oracle^curve * discard,
             # so higher-oracle tokens are dropped more often; re-apply padding mask.
             # note: computed BEFORE scaling the oracle (order matters).
             one_mask_t = (BK.rand(one_oracle_t.shape) >=
                           ((one_oracle_t**conf.cf_loss_discard_curve) *
                            conf.cf_loss_discard)) * mask_t
             one_oracle_t *= conf.cf_scale
         # simple L2 loss
         one_loss_t = (one_scores.squeeze(-1) - one_oracle_t)**2
         one_loss_item = LossHelper.compile_leaf_loss(f"cf{one_li}",
                                                      (one_loss_t *
                                                       one_mask_t).sum(),
                                                      one_mask_t.sum(),
                                                      loss_lambda=loss_cf)
         rets.append(one_loss_item)
     return rets
Example #2
0
 def batched_rev_idxes(self):
     """Lazily build and cache the batched subword->orig index mapping."""
     if self._batched_rev_idxes is None:
         # pad with 0, consistent with the other index padders
         pad0 = DataPadder(2, pad_vals=0)
         rev_arr, _ = pad0.pad(
             [sub.align_info.split2orig for sub in self.seq_subs]
         )  # [bsize, sub_len]
         self._batched_rev_idxes = BK.input_idx(rev_arr)
     return self._batched_rev_idxes  # [bsize, sub_len]
Example #3
0
 def _transform_factors(self, factors: Union[List[List[int]], BK.Expr],
                        is_orig: bool, PAD_IDX: Union[int, float]):
     """Batch/pad `factors`; if given at orig-token granularity, gather to subword positions.

     factors: either raw per-instance id lists or an already-padded tensor.
     is_orig: True when ids are aligned to original tokens and need mapping to subtoks.
     PAD_IDX: padding value used when `factors` still needs padding.
     Returns: [bsize, sub_len] id tensor.
     """
     if not isinstance(factors, BK.Expr):
         # pad raw lists and wrap as an idx tensor
         padded_arr, _ = DataPadder(2, pad_vals=PAD_IDX).pad(factors)
         batched_ids = BK.input_idx(
             padded_arr)  # [bsize, orig-len if is_orig else sub_len]
     else:
         batched_ids = factors  # already padded
     if not is_orig:
         return batched_ids  # already [bsize, sub_len]
     # map orig-level ids onto each subword position
     return batched_ids[self.arange2_t, self.batched_rev_idxes]  # [bsize, sub_len]
Example #4
0
 def __init__(self, name: str, voc: SimpleVocab):
     """Inputter over the idx sequence stored in the instance field `seq_<name>`."""
     super().__init__(name)
     # --
     _field = "seq_" + name
     # fetch the idx list from the instance's corresponding seq field
     self.idx_get_f = lambda inst: getattr(inst, _field).idxes
     self.voc = voc
     self.idx_mask = voc.mask  # use this one!
     self.padder = DataPadder(2, pad_vals=voc.pad)  # dim=2, pad=[pad]
Example #5
0
 def _input_bert(self, insts: List[Sent]):
     """Encode `insts` with BERT through a fresh vstate; return (mask, bert output, vstate)."""
     batch_input = self.berter.create_input_batch_from_sents(insts)
     sent_lens = [len(z) for z in insts]
     mask_expr = BK.input_real(DataPadder.lengths2mask(sent_lens))  # [bs, slen, *]
     # todo(+N): currently ignore emb layer!
     vstate = self.idec_manager.new_vstate(None, mask_expr)
     bert_expr = self.berter.forward(batch_input, vstate=vstate)
     return mask_expr, bert_expr, vstate
Example #6
0
 def prepare_inputs(self, insts: List):
     """Build the batched-input dict: base masks first, then one entry per registered inputter."""
     lengths = [len(z) for z in insts]
     mask_arr = DataPadder.lengths2mask(lengths)
     # keep both the tensor ("mask") and the raw array ("mask_arr")
     ret = OrderedDict()
     ret["mask"] = BK.input_real(mask_arr)
     ret["mask_arr"] = mask_arr
     # delegate the remaining keys to the individual inputters
     for key, one_inputter in self.inputters.items():
         ret[key] = one_inputter.prepare(insts)
     return ret
Example #7
0
 def __init__(self, ibatch: InputBatch, IDX_PAD: int):
     """Batch the per-item `seq_info` fields of `ibatch` into padded tensors.

     IDX_PAD: pad id for the encoder input ids.
     """
     # preps: arange helpers reused for fancy indexing at various ranks
     self.bsize = len(ibatch)
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # batched them
     all_seq_infos = [z.seq_info for z in ibatch.items]
     # enc: [*, len_enc]: ids(pad IDX_PAD), masks, segids(pad 0)
     self.enc_input_ids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_ids for z in all_seq_infos],
                                int(IDX_PAD)))
     self.enc_input_masks = BK.input_real(
         DataPadder.lengths2mask(
             [len(z.enc_input_ids) for z in all_seq_infos]))
     self.enc_input_segids = BK.input_idx(
         DataPadder.go_batch_2d([z.enc_input_segids for z in all_seq_infos],
                                0))
     # dec: [*, len_dec]: sel_idxes(pad 0), sel_lens(pad 1), masks, sent_idxes(pad ??)
     self.dec_sel_idxes = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_idxes for z in all_seq_infos],
                                0))
     self.dec_sel_lens = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_sel_lens for z in all_seq_infos], 1))
     self.dec_sel_masks = BK.input_real(
         DataPadder.lengths2mask(
             [len(z.dec_sel_idxes) for z in all_seq_infos]))
     _max_dec_len = BK.get_shape(self.dec_sel_masks, 1)
     # pad offsets with _max_dec_len so padded offset slots can never be <= any
     # position index and thus contribute 0 to the count below
     _dec_offsets = BK.input_idx(
         DataPadder.go_batch_2d([z.dec_offsets for z in all_seq_infos],
                                _max_dec_len))
     # note: CLS as -1, then 0,1,2,..., PAD gets -2!
     # sentence idx of each position = (#offsets already passed) - 1
     self.dec_sent_idxes = \
         (BK.arange_idx(_max_dec_len).unsqueeze(0).unsqueeze(-1) >= _dec_offsets.unsqueeze(-2)).sum(-1).long() - 1
     self.dec_sent_idxes[self.dec_sel_masks <= 0.] = -2
     # dec -> enc: [*, len_enc] (calculated on needed!)
     # note: require 1-to-1 mapping (except pads)!!
     self._enc_back_hits = None
     self._enc_back_sel_idxes = None
Example #8
0
 def __init__(self, berter: BertEncoder,
              seq_subs: List[InputSubwordSeqField]):
     """Batch subword sequences for a BERT encoder: padded ids/masks plus
     sub<->orig alignment mappings; heavier fields are built lazily."""
     self.seq_subs = seq_subs
     self.berter = berter
     self.bsize = len(seq_subs)
     # arange helpers reused for fancy indexing at various ranks
     self.arange1_t = BK.arange_idx(self.bsize)  # [bsize]
     self.arange2_t = self.arange1_t.unsqueeze(-1)  # [bsize, 1]
     self.arange3_t = self.arange2_t.unsqueeze(-1)  # [bsize, 1, 1]
     # --
     tokenizer = self.berter.tokenizer
     PAD_IDX = tokenizer.pad_token_id
     # MASK_IDX = tokenizer.mask_token_id
     # CLS_IDX_l = [tokenizer.cls_token_id]
     # SEP_IDX_l = [tokenizer.sep_token_id]
     # make batched idxes
     padder = DataPadder(2, pad_vals=PAD_IDX, mask_range=2)
     batched_sublens = [len(s.idxes) for s in seq_subs]  # [bsize]
     batched_input_ids, batched_input_mask = padder.pad(
         [s.idxes for s in seq_subs])  # [bsize, sub_len]
     self.batched_sublens_p1 = BK.input_idx(
         batched_sublens
     ) + 1  # also the idx of EOS (if counting including BOS)
     self.batched_input_ids = BK.input_idx(batched_input_ids)
     self.batched_input_mask = BK.input_real(batched_input_mask)
     # make batched mappings (sub->orig): first-subword idx of each orig token
     padder2 = DataPadder(2, pad_vals=0,
                          mask_range=2)  # pad as 0 to avoid out-of-range
     batched_first_idxes, batched_first_mask = padder2.pad(
         [s.align_info.orig2begin for s in seq_subs])  # [bsize, orig_len]
     self.batched_first_idxes = BK.input_idx(batched_first_idxes)
     self.batched_first_mask = BK.input_real(batched_first_mask)
     # reversed batched_mappings (orig->sub) (created when needed)
     self._batched_rev_idxes = None  # [bsize, sub_len]
     # -- lazily-filled extras (None until set elsewhere)
     self.batched_repl_masks = None  # [bsize, sub_len], to replace with MASK
     self.batched_token_type_ids = None  # [bsize, 1+sub_len+1]
     self.batched_position_ids = None  # [bsize, 1+sub_len+1]
     self.other_factors = {}  # name -> aug_batched_ids
Example #9
0
 def _input_bert(self, insts: List[Sent]):
     """Encode `insts` with BERT; return (mask, bert output)."""
     bert_input = self.berter.create_input_batch_from_sents(insts)
     sent_lens = [len(z) for z in insts]
     mask_expr = BK.input_real(DataPadder.lengths2mask(sent_lens))  # [bs, slen, *]
     return mask_expr, self.berter.forward(bert_input)
Example #10
0
 def __init__(self, name: str, voc: SimpleVocab):
     """Char-level variant: reads char idx sequences from `seq_<name>`."""
     super().__init__(name + "_char", voc)
     # rewrite the defaults set by the parent: fetch char idxes, pad at dim=3
     _field = "seq_" + name
     self.idx_get_f = lambda inst: getattr(inst, _field).get_char_seq().idxes
     self.padder = DataPadder(3, pad_vals=voc.pad)  # dim=3, pad=[pad]
Example #11
0
 def prepare_batch(insts: List, idx_f: Callable, padder: DataPadder):
     """Extract idx lists from each inst via `idx_f`, pad them, and wrap as an idx tensor."""
     idx_lists = [idx_f(inst) for inst in insts]
     padded_arr, _ = padder.pad(idx_lists)
     return BK.input_idx(padded_arr)