Пример #1
0
 def prepare_training_rop(self):
     """Build the pair of RefreshOptions used for a training refresh.

     The first option set is filled from the embedding-side settings of
     ``self.conf`` (``drop_embed``, ``dropmd_embed``); the second from the
     hidden/RNN settings (``drop_hidden``, ``idrop_rnn``, ``gdrop_rnn``).
     Both receive the same ``fix_drop`` value.

     Returns:
         A ``(embed_rop, other_rop)`` tuple.
     """
     c = self.conf
     # options for the embedding side
     for_embeddings = RefreshOptions(
         hdrop=c.drop_embed,
         dropmd=c.dropmd_embed,
         fix_drop=c.fix_drop,
     )
     # options for everything else
     for_others = RefreshOptions(
         hdrop=c.drop_hidden,
         idrop=c.idrop_rnn,
         gdrop=c.gdrop_rnn,
         fix_drop=c.fix_drop,
     )
     return for_embeddings, for_others
Пример #2
0
 def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
     """Embed the inputs, encode them, and optionally apply the word2 augmentation.

     Args:
         cur_input_map: Input mapping handed to ``self.embedder``; an optional
             ``"rel_dist"`` entry is converted with ``BK.input_idx`` and
             forwarded to the encoder(s).
         collect_loss: Forwarded to the "vrec" encoder and to the aug encoder.
         insts: Passed to ``self.aug_word2`` when ``conf.aug_word2`` is set.

     Returns:
         ``(emb_t, mask_t, final_enc_t, cache, enc_loss)``. With the
         "original" encoder choice ``cache`` and ``enc_loss`` are ``None``.

     Raises:
         NotImplementedError: if ``conf.enc_choice`` is neither "vrec" nor
             "original".
         AssertionError: if "rel_dist" is supplied with the "original" encoder.
     """
     conf = self.conf
     # special feature-mode: when the aug encoder is on, the base embedder and
     # encoder are refreshed with training=False before they are run
     if conf.aug_word2 and conf.aug_word2_aug_encoder:
         feature_rop = RefreshOptions(training=False)
         for module in (self.embedder, self.encoder):
             module.refresh(feature_rop)
     # embeddings and mask
     emb_t, mask_t = self.embedder(cur_input_map)
     rel_dist = cur_input_map.get("rel_dist")
     if rel_dist is not None:
         rel_dist = BK.input_idx(rel_dist)
     # base encoder
     enc_choice = conf.enc_choice
     if enc_choice == "vrec":
         enc_t, cache, enc_loss = self.encoder(
             emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
     elif enc_choice == "original":
         # todo(note): the original encoder takes the mask as a plain value for back compatibility
         assert rel_dist is None, "Original encoder does not support rel_dist"
         mask_value = BK.get_value(mask_t)
         enc_t = self.encoder(emb_t, mask_value)
         cache = enc_loss = None
     else:
         raise NotImplementedError()
     # another encoder based on attn: [*, slen, D] => final encoder output
     final_enc_t = self.rpreper(emb_t, enc_t, cache)
     if not conf.aug_word2:
         return emb_t, mask_t, final_enc_t, cache, enc_loss
     emb2_t = self.aug_word2(insts)
     if not conf.aug_word2_aug_encoder:
         # no aug encoder: simply add the extra embeddings
         return emb_t, mask_t, final_enc_t + emb2_t, cache, enc_loss
     # aug encoder: stack the last hidden layers of the base encoder and
     # detach them so they act as features only
     # NOTE(review): cache is None after the "original" branch, so this path
     # presumably requires enc_choice == "vrec" -- confirm
     recent_hidden = cache.list_hidden[-conf.aug_detach_numlayer:]
     stack_hidden_t = BK.stack(recent_hidden, -2).detach()
     features = self.aug_mixturer(stack_hidden_t)
     dropped = self.aug_detach_drop(features)
     aug_input = emb2_t + conf.aug_detach_ratio * dropped
     # the aug encoder's outputs replace final_enc_t, cache and enc_loss
     final_enc_t, cache, enc_loss = self.aug_encoder(
         aug_input, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
     return emb_t, mask_t, final_enc_t, cache, enc_loss
Пример #3
0
 def refresh_batch(self, training: bool):
     """Refresh the graph and every node/component for the next batch.

     ``nn_refresh()`` is always called. In training mode every entry of
     ``self.nodes`` and ``self.components`` is refreshed with dropout
     options read from ``self.conf``. In non-training mode the entries are
     refreshed with ``RefreshOptions(training=False)``, but only if the
     previous refresh was a training one; otherwise the method returns
     right after ``nn_refresh()``.

     Args:
         training: Whether the upcoming batch is a training batch.
     """
     # todo(warn): the graph refresh must never be skipped
     nn_refresh()
     if training:
         settings = self.conf
         rop = RefreshOptions(
             training=True,
             hdrop=settings.drop_hidden,
             idrop=settings.idrop_rnn,
             gdrop=settings.gdrop_rnn,
             fix_drop=settings.fix_drop,
         )
         self.previous_refresh_training = True
     elif self.previous_refresh_training:
         # first non-training refresh after a training one
         self.previous_refresh_training = False
         rop = RefreshOptions(training=False)  # default no dropout
     else:
         # todo(+1): currently no need to refresh testing mode multiple times
         return
     # nodes first, then components
     for group in (self.nodes, self.components):
         for member in group.values():
             member.refresh(rop)
Пример #4
0
 def refresh_batch(self, training: bool):
     """Refresh the graph, ``self.bter`` and every decoder for the next batch.

     ``nn_refresh()`` is always called. In training mode the two option sets
     come from ``self.bter.prepare_training_rop()``. In non-training mode a
     single ``RefreshOptions(training=False)`` object is used for both, and
     the method returns right after ``nn_refresh()`` when the previous
     refresh was already a non-training one.

     Args:
         training: Whether the upcoming batch is a training batch.
     """
     # todo(warn): the graph refresh must never be skipped
     nn_refresh()
     if training:
         embed_rop, other_rop = self.bter.prepare_training_rop()
         # todo(warn): once-bug, the flag must be set back here
         self.previous_refresh_training = True
     else:
         if not self.previous_refresh_training:
             # todo(+1): currently no need to refresh testing mode multiple times
             return
         self.previous_refresh_training = False
         eval_rop = RefreshOptions(training=False)  # default no dropout
         embed_rop, other_rop = eval_rop, eval_rop
     # manually refresh: bter gets both option sets, decoders get the non-embedding one
     self.bter.special_refresh(embed_rop, other_rop)
     for decoder in self.decoders:
         decoder.refresh(other_rop)