def forward(self, x, state):
    """One-hot encode ``x`` per time step and run it through ``self.lstm_layer``.

    Args:
        x: Integer index tensor indexed as ``x[:, i]`` for each ``i`` in
            ``range(x.shape[1])`` (presumably ``(batch, num_steps)`` -- TODO
            confirm against callers).
        state: Hidden state forwarded to ``self.lstm_layer``, or ``None``.
            A tuple state (e.g. an ``(h, c)`` pair, assuming an nn.LSTM-style
            layer) has each member moved to the parameter device; a single
            tensor is moved directly.

    Returns:
        ``(output, new_state)``: ``self.linear_layer`` applied to the LSTM
        output flattened to ``(-1, self.hiddens)``, and the state returned by
        ``self.lstm_layer``.
    """
    # Only the device of the first parameter is needed; avoid building a list.
    device = next(iter(self.parameters())).device
    if state is not None:
        # Match the sibling forward() implementations, which move a passed-in
        # state onto the model device before using it.
        if isinstance(state, tuple):
            state = tuple(s.to(device) for s in state)
        else:
            state = state.to(device)
    # One encoding per time step, stacked along a new leading (step) dim.
    x_vector = torch.stack(
        [one_hot(x[:, i], self.inputs, device) for i in range(x.shape[1])])
    lstm_output, new_state = self.lstm_layer(x_vector, state)
    # reshape (not view) so a non-contiguous LSTM output cannot raise.
    output_vec = self.linear_layer(lstm_output.reshape(-1, self.hiddens))
    return output_vec, new_state
 def forward(self, x, state):
     """Run one-hot encoded time steps of ``x`` through ``self.rnn_layer``.

     Args:
         x: Integer index tensor whose columns ``x[:, t]`` are encoded one
             per step (presumably ``(batch, num_steps)`` -- TODO confirm).
         state: Hidden state for ``self.rnn_layer``, or ``None``; a tensor
             state is moved to the parameter device first.

     Returns:
         ``(output, new_state)``: ``self.linear_layer`` applied to the RNN
         output flattened to ``(-1, self.num_hiddens)``, and the new state.
     """
     target_device = tuple(self.parameters())[0].device
     step_count = x.shape[1]
     # Encode each step with one_hot, stack along a new leading dim, then
     # move the stacked tensor to the parameter device.
     encoded = torch.stack([one_hot(x[:, t], self.class_num) for t in range(step_count)])
     encoded = encoded.to(target_device)
     state = state if state is None else state.to(target_device)
     hidden_seq, next_state = self.rnn_layer(encoded, state)
     flat = hidden_seq.view(-1, self.num_hiddens)
     return self.linear_layer(flat), next_state
示例#3
0
 def forward(self, x, state):
     """Encode ``x`` per time step and delegate to ``self.net``.

     Args:
         x: Integer index tensor; ``x.shape[0]`` sizes a fresh state and each
             column ``x[:, t]`` is one-hot encoded (presumably
             ``(batch, num_steps)`` -- TODO confirm).
         state: Hidden state tensor, or ``None`` to start from zeros of shape
             ``(x.shape[0], self.net.hiddens)`` in float32.

     Returns:
         ``(result, new_state)`` as unpacked from ``self.net``'s return value.
     """
     dev = tuple(self.net.parameters())[0].device
     if state is not None:
         state = state.to(dev)
     else:
         state = torch.zeros(x.shape[0], self.net.hiddens, dtype=torch.float32, device=dev)
     # A plain list (not stacked): self.net receives one encoding per step.
     step_inputs = [one_hot(x[:, col], self.net.inputs, dev) for col in range(x.shape[1])]
     out, next_state = self.net((step_inputs, state))
     return out, next_state
 def forward(self, x, state):
     """Run one-hot encoded time steps of ``x`` through ``self.gru_layer``.

     Args:
         x: Integer index tensor whose columns ``x[:, k]`` are encoded one per
             step (presumably ``(batch, num_steps)`` -- TODO confirm).
         state: Hidden state for ``self.gru_layer``, or ``None``; a tensor
             state is moved to the parameter device first.

     Returns:
         ``(output, new_state)``: ``self.linear_layer`` applied to the GRU
         output flattened to ``(-1, self.hiddens)``, and the new state.
     """
     dev = tuple(self.parameters())[0].device
     state = None if state is None else state.to(dev)
     steps = [one_hot(x[:, k], self.inputs, dev) for k in range(x.shape[1])]
     gru_seq, next_state = self.gru_layer(torch.stack(steps), state)
     projected = self.linear_layer(gru_seq.view(-1, self.hiddens))
     return projected, next_state
 def forward(self, x, state):
     """One-hot encode ``x`` per time step and run ``self.net`` with ``(h, c)``.

     Args:
         x: Integer index tensor; ``x.shape[0]`` sizes fresh states and each
             column ``x[:, i]`` is one-hot encoded (presumably
             ``(batch, num_steps)`` -- TODO confirm).
         state: ``(h, c)`` pair, or ``None`` to build both via
             ``self.init_state(x.shape[0], self.net.hiddens, device)``.

     Returns:
         ``(y, (h, c))``: the output of ``self.net`` and its updated states.
     """
     device = next(iter(self.net.parameters())).device
     # Identity check: `== None` is unidiomatic and ill-defined for tensors.
     if state is None:
         h = self.init_state(x.shape[0], self.net.hiddens, device)
         c = self.init_state(x.shape[0], self.net.hiddens, device)
     else:
         h, c = state
         h = h.to(device)
         c = c.to(device)
     # NOTE(review): the one-hot depth here is self.net.outputs, whereas the
     # other forward()/predict() code in this file encodes with net.inputs.
     # Equivalent only when input and output vocab sizes match -- confirm.
     x_vector = [one_hot(x[:, i], self.net.outputs, device) for i in range(x.shape[1])]
     y, h, c = self.net((x_vector, h, c))
     return y, (h, c)
示例#6
0
    def predict(self, prexs_index, predict_char_num):
        """Prime ``self.net`` with a prefix, then greedily generate indices.

        Args:
            prexs_index: Non-empty sequence of integer indices used as the
                prefix; ``prexs_index[0]`` is the first input.
            predict_char_num: Number of indices to generate after the prefix.

        Returns:
            List containing the prefix indices followed by
            ``predict_char_num`` generated indices, each the argmax over
            ``dim=1`` of ``self.net``'s output for the previous index.

        Raises:
            IndexError: If ``prexs_index`` is empty.
        """
        device = next(iter(self.net.parameters())).device
        outputs = [prexs_index[0]]
        state = torch.zeros(1, self.net.hiddens, dtype=torch.float32, device=device)
        prefix_steps = len(prexs_index) - 1

        # Greedy decoding needs no autograd graph; no_grad avoids building one.
        with torch.no_grad():
            for step in range(predict_char_num + prefix_steps):
                # Create the (1, 1) index tensor on the model device so it
                # agrees with the device passed to one_hot (it used to be
                # created on the CPU regardless of where the model lives).
                x = torch.tensor(outputs[-1], dtype=torch.int64, device=device).view(1, 1)
                # x has exactly one column, so there is exactly one step input.
                x_vector = [one_hot(x[:, 0], self.net.inputs, device)]
                result, state = self.net((x_vector, state))
                if step < prefix_steps:
                    # Still consuming the prefix: feed the true next index.
                    outputs.append(prexs_index[step + 1])
                else:
                    # .item() works on any device; no explicit .cpu() needed.
                    outputs.append(result.argmax(dim=1).view(-1).item())
        return outputs
示例#7
0
 def to_one_hot(self, x, num_step, class_num):
     """Return a list of ``one_hot`` encodings, one per column of ``x``.

     Args:
         x: Index tensor indexed as ``x[:, step]``.
         num_step: Number of leading columns of ``x`` to encode.
         class_num: Encoding depth passed to ``one_hot``.

     Returns:
         List of length ``num_step``; element ``step`` is
         ``one_hot(x[:, step], class_num)``.
     """
     encoded_steps = []
     for step in range(num_step):
         encoded_steps.append(one_hot(x[:, step], class_num))
     return encoded_steps