コード例 #1
0
ファイル: base_idec.py プロジェクト: zzsfornlp/zmsp
 def forward_input(self, med: ZMediator, detach_scale: float, **kwargs):
     """Fetch the encoder "hid" value from `med`, optionally reduce it, then pass it through BK.go_detach.

     Args:
         med: mediator supplying the encoder cache value "hid", the mask (via `get_mask`),
             and, for sequence selection, a cached index tensor (via `get_cache`).
         detach_scale: forwarded unchanged as the second argument of ``BK.go_detach``.
         **kwargs: accepted for interface compatibility; not used here.

     Returns:
         A pair ``(out_t, mask_t)``. Per the author's shape annotations, ``out_t`` is
         [*, ??, D], or [*, D] when pooling/selection is enabled; ``mask_t`` is [*, ??].
     """
     # Read the "hid" encoder value. With dsel enabled, `self.dsel`'s signature and a
     # transform function are handed to the mediator; otherwise the plain value is requested.
     if not self.do_dsel:
         hid_t = med.get_enc_cache_val("hid")  # [*, ??, D]
     else:
         sel_mod = self.dsel

         def _apply_dsel(x):
             # seq_info is read from med.ibatch at call time, just like the original lambda
             return sel_mod.forward(x, med.ibatch.seq_info)

         hid_t = med.get_enc_cache_val("hid", signature=sel_mod.signature, function=_apply_dsel)  # [*, ??, D]
     mask_t = med.get_mask(self.do_dsel)  # [*, ??]
     # Optional reduction over the second (sequence) axis.
     if self.do_seq_pool:
         out_t = self.seq_pool_f(hid_t)  # [*, D]
     elif self.do_seq_sel:
         # One position per row: row i -> hid_t[i, col_ids[i]]
         # (assumes numpy/torch-style paired integer-array indexing on BK tensors).
         row_ids = BK.arange_idx(BK.get_shape(hid_t, 0))  # [*]
         col_ids = med.get_cache(self.seq_sel_key)  # [*]
         out_t = hid_t[row_ids, col_ids]  # [*, D]
     else:
         out_t = hid_t
     # Hand off to BK.go_detach together with the current training flag.
     return BK.go_detach(out_t, detach_scale, self.is_training()), mask_t
コード例 #2
0
 def _go_detach(self, x):
     """Run ``BK.go_detach`` on *x* using ``self.detach_scale`` and the current training flag."""
     scale = self.detach_scale
     training = self.is_training()
     return BK.go_detach(x, scale, training)