Code example #1
    def get_recurrent_layers_outputs(self, ast_input: ASTInput, combined_input,
                                     hidden, forget_vector):
        hidden = repackage_hidden(
            forget_hidden_partly_lstm_cell(hidden,
                                           forget_vector=forget_vector))
        self.last_k_attention.repackage_and_forget_buffer_partly(forget_vector)

        recurrent_output = []
        layered_attn_output = []
        for i in range(combined_input.size(0)):
            reinit_dropout = i == 0  # re-sample dropout masks on the first time step

            # core recurrent part
            cur_h, cur_c = self.recurrent_cell(combined_input[i],
                                               hidden,
                                               reinit_dropout=reinit_dropout)
            hidden = (cur_h, cur_c)
            recurrent_output.append(cur_h)

            # layered part
            attn_output = self.last_k_attention(ast_input.non_terminals[i],
                                                cur_h)
            layered_attn_output.append(attn_output)

        # combine outputs from different layers
        recurrent_output = torch.stack(recurrent_output, dim=0)
        layered_attn_output = torch.stack(layered_attn_output, dim=0)

        return recurrent_output, hidden, layered_attn_output
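The helpers `repackage_hidden` and `forget_hidden_partly_lstm_cell` are not shown on this page. The first follows the standard PyTorch pattern of detaching recurrent state from its computation graph at batch boundaries; the second apparently zeroes the state of batch entries whose program file changed. A minimal sketch, assuming `forget_vector` is a `(batch, 1)` tensor of ones (keep) and zeros (reset):

    import torch

    def repackage_hidden(h):
        # detach the state from its history so backprop stops at the batch boundary
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(repackage_hidden(v) for v in h)

    def forget_hidden_partly_lstm_cell(h, forget_vector):
        # sketch under the assumptions above: h is an (h, c) pair from an LSTM
        # cell; zero the rows of entries whose file changed
        return h[0].mul(forget_vector), h[1].mul(forget_vector)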
Code example #2
    def get_recurrent_output(self, combined_input, ast_input: ASTInput, m_hidden, forget_vector):
        hidden, layered_hidden = m_hidden
        nodes_depth = ast_input.nodes_depth

        # repackage hidden and forget hidden if the program file changed
        hidden = repackage_hidden(forget_hidden_partly_lstm_cell(hidden, forget_vector=forget_vector))
        layered_hidden = LayeredRecurrentUpdateAfter.repackage_and_partly_forget_hidden(
            layered_hidden=layered_hidden,
            forget_vector=forget_vector
        )
        self.last_k_attention.repackage_and_forget_buffer_partly(forget_vector)

        # clamp node depths so indices stay within the self.num_tree_layers stored layers
        nodes_depth = torch.clamp(nodes_depth, min=0, max=self.num_tree_layers - 1)

        recurrent_output = []
        attn_output = []
        layered_output = []
        b_h = None  # tracks the last step's hidden state for the sanity check below
        for i in range(combined_input.size(0)):
            reinit_dropout = i == 0  # re-sample dropout masks on the first time step

            # core recurrent part
            cur_h, cur_c = self.recurrent_cell(combined_input[i], hidden, reinit_dropout=reinit_dropout)
            hidden = (cur_h, cur_c)
            b_h = hidden
            recurrent_output.append(cur_h)

            # attn part
            cur_attn_output = self.last_k_attention(cur_h)
            attn_output.append(cur_attn_output)

            # layered part
            l_h, l_c = self.layered_recurrent(
                combined_input[i],
                nodes_depth[i],
                layered_hidden=layered_hidden,
                reinit_dropout=reinit_dropout
            )

            layered_hidden = LayeredRecurrentUpdateAfter.update_layered_lstm_hidden(
                layered_hidden=layered_hidden,
                node_depths=nodes_depth[i],
                new_value=(l_h, l_c)
            )

            layered_output_coefficients = self.layered_attention(l_h, layered_hidden[0])
            cur_layered_output = calc_attention_combination(layered_output_coefficients, layered_hidden[0])
            layered_output.append(cur_layered_output)  # collected per step; torch.cat could be used instead of stack below

        # combine outputs from different layers
        recurrent_output = torch.stack(recurrent_output, dim=0)
        attn_output = torch.stack(attn_output, dim=0)
        layered_output = torch.stack(layered_output, dim=0)

        assert b_h is hidden  # identity check: `hidden` still holds the last step's state
        concatenated_output = torch.cat((recurrent_output, attn_output, layered_output), dim=-1)

        return concatenated_output, (hidden, layered_hidden)
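`calc_attention_combination` is likewise not defined here; from the call site it takes per-layer attention coefficients and the layered hidden matrix and returns their weighted combination. A hedged sketch, assuming coefficients of shape `(batch, num_tree_layers, 1)` and a hidden matrix of shape `(batch, num_tree_layers, hidden_size)`:

    def calc_attention_combination(attention_weights, matrix):
        # weighted sum over the layer dimension; the shapes are assumptions,
        # not taken from this page
        return (attention_weights * matrix).sum(dim=1)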
Code example #3
    def get_recurrent_output(self, combined_input, ast_input: ASTInput,
                             m_hidden, forget_vector):
        hidden = m_hidden

        # drop state for batch entries whose program file changed, then detach from the graph
        hidden = forget_hidden_partly(hidden, forget_vector=forget_vector)
        hidden = repackage_hidden(hidden)

        recurrent_output, new_hidden = self.recurrent_core(
            combined_input, hidden)

        return recurrent_output, new_hidden
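Example #3 uses `forget_hidden_partly` (without the `_lstm_cell` suffix), suggesting the state here comes from a stacked `nn.LSTM`, whose `(h, c)` tensors have shape `(num_layers, batch, hidden_size)`. A plausible sketch, assuming the same `(batch, 1)` forget mask:

    def forget_hidden_partly(h, forget_vector):
        # broadcast the (batch, 1) mask over the layer and feature dimensions
        # of (num_layers, batch, hidden) state tensors; an assumption, not
        # the project's actual implementation
        return tuple(t.mul(forget_vector.unsqueeze(0)) for t in h)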
Code example #4
    def repackage_and_partly_forget_hidden(layered_hidden, forget_vector):
        # zero per-layer (h, c) rows for changed files, then detach from the graph
        layered_hidden = forget_hidden_partly_lstm_cell(
            h=layered_hidden, forget_vector=forget_vector.unsqueeze(1))
        return repackage_hidden(layered_hidden)
Code example #5
File: ast_core.py  Project: zerogerc/rnn-autocomplete
    def repackage_and_forget_buffer_partly(self, forget_vector):
        # scale each buffered tensor by the forget mask, then detach it
        self.buffer = [
            repackage_hidden(b.mul(forget_vector)) for b in self.buffer
        ]
Code example #6
    def repackage_and_forget_buffer_partly(self, forget_vector):
        # self.buffer = self.buffer.mul(forget_vector.unsqueeze(1)) TODO: implement forgetting
        self.buffer = [repackage_hidden(b) for b in self.buffer]
Code example #7
File: model.py  Project: zerogerc/rnn-autocomplete
    def get_recurrent_output(self, input_embedded, hidden, forget_vector):
        # reset state for changed files and detach, then run the LSTM over the sequence
        hidden = repackage_hidden(forget_hidden_partly(hidden, forget_vector))
        return self.lstm(input_embedded, hidden)
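All seven examples share one pattern: at every batch boundary, build a forget vector from a "file changed" flag, zero the affected state, and detach the rest. A hypothetical driver loop; `model`, `batches`, `file_changed`, and `init_hidden` are illustrative stand-ins, not APIs shown on this page:

    hidden = model.init_hidden(batch_size)  # assumed initializer, not from this page
    for batch in batches:
        # 1.0 keeps an entry's state across batches, 0.0 resets it
        forget_vector = 1.0 - batch.file_changed.float().unsqueeze(1)  # (batch, 1)
        output, hidden = model.get_recurrent_output(
            batch.input_embedded, hidden, forget_vector)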