import torch


def get_recurrent_output(self, combined_input, ast_input: ASTInput, m_hidden, forget_vector):
    # Reset the hidden state for batch entries marked by the forget vector,
    # then detach it from the autograd graph before running the core RNN.
    hidden = forget_hidden_partly(m_hidden, forget_vector=forget_vector)
    hidden = repackage_hidden(hidden)
    recurrent_output, new_hidden = self.recurrent_core(combined_input, hidden)
    return recurrent_output, new_hidden
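# The helpers used above are not shown in this excerpt. A minimal sketch of
# what they presumably do, assuming hidden states are tensors or (possibly
# nested) tuples of tensors and forget_vector is a (batch, 1) mask of ones
# (keep) and zeros (reset):

def repackage_hidden(h):
    # Detach the hidden state from the autograd graph so backpropagation is
    # truncated at the chunk boundary (the standard truncated-BPTT idiom).
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


def forget_hidden_partly(hidden, forget_vector):
    # Zero the hidden state of batch entries whose sequence just restarted;
    # a (batch, 1) mask broadcasts over (num_layers, batch, hidden_size).
    if isinstance(hidden, torch.Tensor):
        return hidden.mul(forget_vector)
    return tuple(forget_hidden_partly(h, forget_vector) for h in hidden)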
def get_recurrent_layers_outputs(self, ast_input: ASTInput, combined_input, hidden, forget_vector):
    # Reset finished sequences, detach the (h, c) state, and clear the
    # corresponding rows of the attention buffer.
    hidden = repackage_hidden(
        forget_hidden_partly_lstm_cell(hidden, forget_vector=forget_vector))
    self.last_k_attention.repackage_and_forget_buffer_partly(forget_vector)

    recurrent_output = []
    layered_attn_output = []
    for i in range(combined_input.size(0)):
        # Re-initialize the recurrent dropout mask only on the first step of the chunk.
        reinit_dropout = i == 0
        cur_h, cur_c = self.recurrent_cell(combined_input[i], hidden,
                                           reinit_dropout=reinit_dropout)
        hidden = (cur_h, cur_c)
        recurrent_output.append(cur_h)
        # Attend over the last k hidden states, keyed by the current non-terminal.
        attn_output = self.last_k_attention(ast_input.non_terminals[i], cur_h)
        layered_attn_output.append(attn_output)

    recurrent_output = torch.stack(recurrent_output, dim=0)
    layered_attn_output = torch.stack(layered_attn_output, dim=0)
    return recurrent_output, hidden, layered_attn_output
def repackage_and_partly_forget_hidden(layered_hidden, forget_vector):  # checked
    # Same reset-then-detach step for a per-layer hidden state; the extra
    # unsqueeze adds a dimension so the mask broadcasts over per-layer states.
    layered_hidden = forget_hidden_partly_lstm_cell(
        h=layered_hidden, forget_vector=forget_vector.unsqueeze(1))
    return repackage_hidden(layered_hidden)
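# forget_hidden_partly_lstm_cell is also absent from this excerpt. A minimal
# sketch under the same assumptions, for an LSTM-style (h, c) pair:

def forget_hidden_partly_lstm_cell(h, forget_vector):
    # Apply the keep/reset mask to both the hidden and the cell state.
    return tuple(state.mul(forget_vector) for state in h)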
def repackage_and_forget_buffer_partly(self, forget_vector):
    # Mask out buffered attention states for restarted sequences and detach
    # every buffer entry from the autograd graph.
    self.buffer = [
        repackage_hidden(b.mul(forget_vector)) for b in self.buffer
    ]
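# Hypothetical usage at a chunk boundary, assuming the forget vector marks
# batch entries whose file just ended so their recurrent state and attention
# buffer are wiped before the next chunk (file_ended is an illustrative name,
# not part of the original code):

file_ended = torch.tensor([False, True, False, False])  # batch of size 4
forget_vector = (~file_ended).float().unsqueeze(1)      # shape (4, 1): 1 keep, 0 reset
# forget_vector is then passed to the reset/detach functions defined above.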