def prepare_training_rop(self):
    mconf = self.conf
    # dropout options for the embedder (adds multi-dim embedding dropout)
    embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop)
    # dropout options for all other nodes (adds RNN input/gate dropout)
    other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn,
                               gdrop=mconf.gdrop_rnn, fix_drop=mconf.fix_drop)
    return embed_rop, other_rop
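# ----- Illustrative sketch (not part of the toolkit): the two bundles above differ only
# in which dropout knobs they carry. Assuming RefreshOptions is a plain field container,
# a minimal stand-in looks like this; the field names mirror the calls above, the
# semantics in the comments are an interpretation, everything else is hypothetical.
from dataclasses import dataclass

@dataclass
class _RefreshOptionsSketch:
    training: bool = True
    hdrop: float = 0.0      # dropout on hidden outputs
    dropmd: float = 0.0     # multi-dim dropout over embedding dimensions (embedder only)
    idrop: float = 0.0      # RNN input dropout
    gdrop: float = 0.0      # RNN gate/recurrent dropout
    fix_drop: bool = False  # keep the sampled dropout mask fixed across steps

# e.g.: embed_rop carries (hdrop, dropmd); other_rop carries (hdrop, idrop, gdrop)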
def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
    conf = self.conf
    # -----
    # special mode: freeze the original embedder/encoder as pure feature extractors
    if conf.aug_word2 and conf.aug_word2_aug_encoder:
        _rop = RefreshOptions(training=False)  # special feature-mode!!
        self.embedder.refresh(_rop)
        self.encoder.refresh(_rop)
    # -----
    emb_t, mask_t = self.embedder(cur_input_map)
    rel_dist = cur_input_map.get("rel_dist", None)
    if rel_dist is not None:
        rel_dist = BK.input_idx(rel_dist)
    if conf.enc_choice == "vrec":
        enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
    elif conf.enc_choice == "original":
        # todo(note): change back to arr for back compatibility
        assert rel_dist is None, "Original encoder does not support rel_dist"
        enc_t = self.encoder(emb_t, BK.get_value(mask_t))
        cache, enc_loss = None, None
    else:
        raise NotImplementedError()
    # another encoder based on attn
    final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
    if conf.aug_word2:
        emb2_t = self.aug_word2(insts)
        if conf.aug_word2_aug_encoder:
            # add them all together, detaching the original encoder's states as fixed features
            # (this path requires the vrec encoder, since the original one returns no cache)
            stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
            features = self.aug_mixturer(stack_hidden_t)
            aug_input = emb2_t + conf.aug_detach_ratio * self.aug_detach_drop(features)
            final_enc_t, cache, enc_loss = self.aug_encoder(
                aug_input, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
        else:
            final_enc_t = final_enc_t + emb2_t  # otherwise, simply add the extra embeddings
    return emb_t, mask_t, final_enc_t, cache, enc_loss
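# ----- Illustrative sketch (not from the toolkit): the aug_word2_aug_encoder branch
# stacks the last N hidden layers of the frozen encoder, detaches them so no gradient
# flows back, mixes them down to one vector per token, and adds them (scaled, with
# dropout) to the new embeddings. In plain PyTorch, with an ELMo-style scalar mix
# standing in for aug_mixturer (all names here hypothetical):
import torch

def _mix_detached_features(list_hidden, emb2_t, layer_weights, drop, ratio):
    # [B, slen, N, D]: stack the chosen hidden layers along a new axis, cut gradients
    stack_hidden_t = torch.stack(list_hidden, dim=-2).detach()
    mix = torch.softmax(layer_weights, dim=0)                    # learned weights over N layers
    features = (stack_hidden_t * mix.view(1, 1, -1, 1)).sum(-2)  # weighted mix -> [B, slen, D]
    return emb2_t + ratio * drop(features)                       # scaled feature residual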
def refresh_batch(self, training: bool):
    # refresh graph
    # todo(warn): make sure to remember to clear this one
    nn_refresh()
    # refresh nodes
    if not training:
        if not self.previous_refresh_training:
            # todo(+1): currently no need to refresh testing mode multiple times
            return
        self.previous_refresh_training = False
        rop = RefreshOptions(training=False)  # default: no dropout
    else:
        conf = self.conf
        rop = RefreshOptions(training=True, hdrop=conf.drop_hidden, idrop=conf.idrop_rnn,
                             gdrop=conf.gdrop_rnn, fix_drop=conf.fix_drop)
        self.previous_refresh_training = True
    for node in self.nodes.values():
        node.refresh(rop)
    for node in self.components.values():
        node.refresh(rop)
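# ----- Illustrative sketch (not from the toolkit): previous_refresh_training makes
# repeated eval-mode refreshes a no-op, since the eval options never change between
# calls. A toy reproduction of just the guard (all names hypothetical):
class _RefreshGuardSketch:
    def __init__(self):
        self.previous_refresh_training = True  # force real work on the first call
        self.refresh_count = 0

    def refresh_batch(self, training: bool):
        if not training:
            if not self.previous_refresh_training:
                return  # already refreshed for eval; skip the per-node loop
            self.previous_refresh_training = False
        else:
            self.previous_refresh_training = True
        self.refresh_count += 1  # stands in for the node.refresh(...) loops

m = _RefreshGuardSketch()
for _ in range(3):
    m.refresh_batch(training=False)
assert m.refresh_count == 1  # only the first eval call does any work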
def refresh_batch(self, training: bool):
    # refresh graph
    # todo(warn): make sure to remember to clear this one
    nn_refresh()
    # refresh nodes
    if not training:
        if not self.previous_refresh_training:
            # todo(+1): currently no need to refresh testing mode multiple times
            return
        self.previous_refresh_training = False
        embed_rop = other_rop = RefreshOptions(training=False)  # default: no dropout
    else:
        embed_rop, other_rop = self.bter.prepare_training_rop()
        # todo(warn): once-bug, don't forget this one!!
        self.previous_refresh_training = True
    # manually refresh: the encoder side takes both rops, the decoders only the general one
    self.bter.special_refresh(embed_rop, other_rop)
    for node in self.decoders:
        node.refresh(other_rop)
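# ----- Illustrative usage sketch (hypothetical driver, not from the toolkit):
# refresh_batch is meant to run once per batch, so training batches get fresh
# dropout masks while decoding pays the refresh cost only once. The model method
# names below are placeholders.
def _train_and_decode_sketch(model, train_batches, test_batches):
    for batch in train_batches:
        model.refresh_batch(training=True)   # new training rops for every batch
        model.train_on_batch(batch)          # placeholder forward/backward call
    model.refresh_batch(training=False)      # one eval refresh covers all test batches
    for batch in test_batches:
        model.predict_on_batch(batch)        # placeholder decode call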