예제 #1
0
파일: expr.py 프로젝트: zzsfornlp/zmsp
 def _arrange_idxes(slices):
     """Collect the distinct underlying values of ``slices`` and flatten their indexes.

     Each element of ``slices`` exposes ``.ew`` (an object with ``.id``,
     ``.val`` and ``.bsize``) and ``.slice_idx``. Objects sharing the same
     ``ew.id`` are stored only once, in first-seen order. Their batches are
     conceptually concatenated: the k-th distinct value starts at the sum of
     the ``bsize`` of all earlier distinct values. Each slice is mapped to its
     position in that concatenation.

     Returns:
         ``(values, bidxes)``, where ``values`` is the list of distinct
         ``ew.val`` and ``bidxes`` is the list of flat indexes, one per slice.
         ``bidxes`` is ``None`` when ``Helper.check_is_range`` accepts it
         against the total concatenated size. Presumably this means the
         indexes are exactly ``0..total-1`` and no reordering is needed;
         TODO confirm against ``Helper``.
     """
     vals = []
     flat_idxes = []
     # starts[k] is the offset of the k-th distinct value; starts[-1] is the running total
     starts = [0]
     pos_of_id = {}
     for one_slice in slices:
         ew, local_idx = one_slice.ew, one_slice.slice_idx
         key = ew.id
         pos = pos_of_id.get(key)
         if pos is None:
             # first time this underlying object is seen: register it
             pos = len(vals)
             pos_of_id[key] = pos
             vals.append(ew.val)
             starts.append(ew.bsize + starts[-1])
         flat_idxes.append(starts[pos] + local_idx)
     perfect_match = Helper.check_is_range(flat_idxes, starts[-1])
     return vals, (None if perfect_match else flat_idxes)
예제 #2
0
 def arange_cache(self, bidxes):
     """Re-select the batch dimension of the cached state with ``bidxes``.

     If ``Helper.check_is_range(bidxes, self.cur_bsize)`` is false, the
     method does the following:
     - ``index_select`` along dim 0 on every cached tensor;
     - rearranges the nested ``scoring_cache`` with the same indexes;
     - calls ``update_bsize(len(bidxes))``.
     If the check is true, this selection is skipped. Presumably that
     happens when ``bidxes`` is already the identity over the current
     batch; TODO confirm against ``Helper``.

     Args:
         bidxes: indexes into the current batch. Its length becomes the
             new batch size.
     """
     target_bsize = len(bidxes)
     needs_select = not Helper.check_is_range(bidxes, self.cur_bsize)
     if needs_select:
         # the masks are kept on CPU (easier to assign into), so select them with a CPU index
         cpu_idxes = BK.input_idx(bidxes, BK.CPU_DEVICE)
         for cpu_attr in ("scoring_fixed_mask_ct", "scoring_mask_ct",
                          "oracle_mask_ct"):
             cpu_tensor = getattr(self, cpu_attr)
             setattr(self, cpu_attr, cpu_tensor.index_select(0, cpu_idxes))
         # everything else lives on the target device (possibly a GPU)
         dev_idxes = BK.to_device(cpu_idxes)
         self.enc_repr = self.enc_repr.index_select(0, dev_idxes)
         self.scoring_cache.arange_cache(dev_idxes)
         # oracle tensors on the target device
         for dev_attr in ("oracle_mask_t", "oracle_label_t"):
             dev_tensor = getattr(self, dev_attr)
             setattr(self, dev_attr, dev_tensor.index_select(0, dev_idxes))
         # record the new batch size
         self.update_bsize(target_bsize)