Exemplo n.º 1
0
 def get_init_state(self, batsize, device=torch.device("cpu")):
     """Return a zero-initialised State for a batch of ``batsize`` sequences.

     :param batsize:  number of sequences in the batch
     :param device:   torch device on which the state tensors are allocated
     :return:         State with
                        .h, .c    zeros of shape (batsize, numlayers, hdim)
                        .levels   zeros of shape (batsize,)
     """
     # `template` is never stored; it only fixes shape, dtype and device
     # for the zeros_like calls below.
     template = torch.ones(batsize, self.numlayers, self.hdim, device=device)
     fresh = State()
     fresh.h = torch.zeros_like(template)
     fresh.c = torch.zeros_like(template)
     # Selecting [:, 0, 0] leaves one scalar slot per batch element.
     fresh.levels = torch.zeros_like(template[:, 0, 0])
     return fresh
Exemplo n.º 2
0
 def get_init_state(self, batsize, device=torch.device("cpu")):
     """Return a zero-initialised State together with recurrent-dropout masks.

     :param batsize:  number of sequences in the batch
     :param device:   torch device on which the state tensors are allocated
     :return:         State with
                        .h, .c                  zeros of shape (batsize, numlayers, hdim)
                        .h_dropout, .c_dropout  masks of that same shape, drawn once
                                                here (presumably reused by forward()
                                                at every step)
     """
     template = torch.ones(batsize, self.numlayers, self.hdim, device=device)

     def draw_mask():
         # NOTE(review): assuming self.dropout_rec is a torch.nn.Dropout, in
         # training mode it yields 0 or 1/(1-p) per entry; clamp(0, 1) turns
         # that into a plain 0/1 mask, i.e. the usual 1/(1-p) rescaling is
         # dropped. In eval mode Dropout is the identity, so the mask is all
         # ones. Confirm the missing rescale is intended.
         return self.dropout_rec(torch.ones_like(template)).clamp(0, 1)

     fresh = State()
     fresh.h = torch.zeros_like(template)
     fresh.c = torch.zeros_like(template)
     # Two independent draws: the h mask first, then the c mask.
     fresh.h_dropout = draw_mask()
     fresh.c_dropout = draw_mask()
     return fresh
Exemplo n.º 3
0
 def get_init_state(self, batsize, device=torch.device("cpu")):
     """Build the initial state for the main/reduce LSTM pair.

     :param batsize:  number of sequences in the batch
     :param device:   torch device forwarded to both sub-LSTMs
     :return:         State whose .h and .c are taken from the main LSTM's
                      initial state, and whose .stack is a 1-D numpy object
                      array of length ``batsize``; entry i is a list holding a
                      single tuple (main_init[i:i+1], reduce_init[i:i+1])
                      (presumably per-example slices of the batched states)
     """
     main_init = self.main_lstm.get_init_state(batsize, device)
     reduce_init = self.reduce_lstm.get_init_state(batsize, device)
     result = State()
     result.h = main_init.h
     result.c = main_init.c
     # An object-dtype array lets every batch element own an independent
     # Python list (used as a per-example stack).
     slots = range(batsize)
     result.stack = np.empty(len(slots), dtype="object")
     for idx in slots:
         result.stack[idx] = [(main_init[idx:idx + 1], reduce_init[idx:idx + 1])]
     return result
Exemplo n.º 4
0
 def forward(self, inp:torch.Tensor, state:State):
     """Run a single time step of the recurrent cell.

     The state is kept batch-first: .h and .c have shape
     (batsize, numlayers, hdim). They are transposed to
     (numlayers, batsize, hdim) before being handed to ``self.cell`` and
     transposed back afterwards. (Presumably ``self.cell`` is a
     torch.nn.LSTM built with batch_first=True, whose hidden/cell states are
     (num_layers, batch, hidden) regardless of batch_first -- TODO confirm.)

     :param inp:     (batsize, indim) input for the current step
     :param state:   State with .h, .c of shape (batsize, numlayers, hdim);
                     when ``self.dropout_rec.p > 0`` it must also carry
                     .h_dropout / .c_dropout masks, which are multiplied
                     elementwise into .h / .c
     :return:        tuple (out, state): ``out`` is the cell output for this
                     step with the length-1 time axis removed (presumably
                     (batsize, hdim)); ``state`` is the same object that was
                     passed in, with .h and .c replaced by the new batch-first
                     states
     """
     _x = self.dropout(inp)
     # Recurrent dropout masks only the state fed into the cell. The unmasked
     # new state is what gets stored below.
     if self.dropout_rec.p > 0:
         h_prev = state.h * state.h_dropout
         c_prev = state.c * state.c_dropout
     else:
         h_prev = state.h
         c_prev = state.c
     # Batch-first (B, L, H) -> (L, B, H); contiguous() because the transpose
     # yields a non-contiguous view.
     hx = (h_prev.transpose(0, 1).contiguous(), c_prev.transpose(0, 1).contiguous())
     # _x[:, None, :] adds a length-1 time axis at dim 1 (batch-first sequence).
     out, (h_n, c_n) = self.cell(_x[:, None, :], hx)
     # Store the new state back in batch-first layout.
     state.h = h_n.transpose(0, 1)
     state.c = c_n.transpose(0, 1)
     return out[:, 0, :], state