def setup(input_data):
    """Build a new ``ASTInput`` whose tensors went through ``setup_tensor``.

    ``setup_tensor`` is defined elsewhere; the original docstring describes
    it as placing tensors on the needed devices.

    Args:
        input_data: object exposing ``non_terminals``, ``terminals``,
            ``nodes_depth`` and ``nodes_depth_target`` attributes.

    Returns:
        A fresh ``ASTInput`` built from the converted tensors.
    """
    non_terminals = setup_tensor(input_data.non_terminals)
    terminals = setup_tensor(input_data.terminals)
    nodes_depth = setup_tensor(input_data.nodes_depth)
    # NOTE(review): the original carried the remark "no gradients should be
    # computed" beside the depth fields. Nothing in this function (e.g. a
    # detach) enforces that; confirm whether setup_tensor handles it.
    nodes_depth_target = setup_tensor(input_data.nodes_depth_target)
    return ASTInput(
        non_terminals=non_terminals,
        terminals=terminals,
        nodes_depth=nodes_depth,
        nodes_depth_target=nodes_depth_target,
    )
def init_hidden(self, batch_size):
    """Create zero-filled ``(h, c)`` states for the tree layers.

    Both tensors have shape
    ``(batch_size, self.num_tree_layers, self.single_hidden_size)`` and are
    passed through ``setup_tensor`` before being returned.

    Args:
        batch_size: size of the leading dimension of each tensor.

    Returns:
        Tuple ``(h, c)`` of two separately allocated zero tensors.
    """
    state_shape = (batch_size, self.num_tree_layers, self.single_hidden_size)
    # Two independent allocations: h and c must not share storage.
    hidden_state, cell_state = (
        setup_tensor(torch.zeros(state_shape)) for _ in range(2)
    )
    return hidden_state, cell_state
def init_hidden(self, batch_size):
    """Create zero-filled recurrent state for ``self.model_type``.

    Every tensor has shape ``(self.num_layers, batch_size, self.hidden_size)``
    and is passed through ``setup_tensor``.

    Args:
        batch_size: size of the middle dimension of each tensor.

    Returns:
        ``(h, c)`` when ``self.model_type == 'lstm'``; otherwise only ``h``.
    """
    state_shape = (self.num_layers, batch_size, self.hidden_size)
    hidden_state = setup_tensor(torch.zeros(state_shape))
    is_lstm = self.model_type == 'lstm'
    if not is_lstm:
        return hidden_state
    # Only the 'lstm' path builds a second (cell) tensor.
    cell_state = setup_tensor(torch.zeros(state_shape))
    return hidden_state, cell_state
def init_buffer(self, batch_size):
    """Reset ``self.buffer`` to a list of zero tensors.

    The list holds ``self.window_len`` tensors, or ``2 * self.window_len``
    when ``self.is_eval`` is truthy. Each tensor has shape
    ``(batch_size, self.hidden_size)`` and is passed through ``setup_tensor``.

    Args:
        batch_size: size of the leading dimension of each tensor.
    """
    multiplier = 2 if self.is_eval else 1
    slot_count = multiplier * self.window_len
    # Build into a local first so self.buffer is replaced in one assignment.
    fresh_buffer = []
    for _ in range(slot_count):
        fresh_buffer.append(
            setup_tensor(torch.zeros((batch_size, self.hidden_size)))
        )
    self.buffer = fresh_buffer
def create_lstm_cell_hidden(hidden_size, batch_size):
    """Return zero-filled ``(h, c)`` tensors for a single LSTM cell.

    Note the argument order: ``hidden_size`` comes before ``batch_size``.

    Args:
        hidden_size: size of the trailing dimension.
        batch_size: size of the leading dimension.

    Returns:
        Tuple ``(h, c)``. Each element has shape ``(batch_size, hidden_size)``
        and was passed through ``setup_tensor``.
    """
    cell_shape = (batch_size, hidden_size)
    return tuple(setup_tensor(torch.zeros(cell_shape)) for _ in range(2))
def init_buffer(self, batch_size):
    """Reset ``self.buffer`` to ``self.window_len`` zero tensors.

    Each tensor has shape ``(batch_size, self.hidden_size)`` and is passed
    through ``setup_tensor``.

    Args:
        batch_size: size of the leading dimension of each tensor.
    """
    fresh_buffer = []
    for _ in range(self.window_len):
        zeros = torch.zeros((batch_size, self.hidden_size))
        fresh_buffer.append(setup_tensor(zeros))
    # Assign once, after every slot has been created.
    self.buffer = fresh_buffer
def setup(target_data):
    """Build a new ``ASTTarget`` whose tensors went through ``setup_tensor``.

    ``setup_tensor`` is defined elsewhere; the original docstring describes
    it as placing tensors on the needed devices.

    Args:
        target_data: object exposing ``non_terminals`` and ``terminals``.

    Returns:
        A fresh ``ASTTarget`` built from the converted tensors.
    """
    converted = {
        field: setup_tensor(getattr(target_data, field))
        for field in ('non_terminals', 'terminals')
    }
    return ASTTarget(**converted)