def forward(self, x):
    """Run the multi-variable LSTM with temporal and variable attention.

    Args:
        x: input tensor, assumed shape (batch, time, input_dim) -- the time
           axis is ``x.shape[1]``; each step maintains one ``n_units`` hidden
           vector per input variable. (TODO confirm shape against callers.)

    Returns:
        mean:   (batch, 1) attention-weighted prediction.
        alphas: temporal attention weights, normalised over dim 1 (time).
        betas:  variable attention weights, normalised over dim 1 (variables).
    """
    # Per-variable hidden/cell state. Allocate on x's device instead of
    # hard-coding .cuda() so the module also runs on CPU and places tensors
    # on the correct GPU in multi-device setups.
    h_tilda_t = torch.zeros(
        x.shape[0], self.input_dim, self.n_units, device=x.device
    )
    c_tilda_t = torch.zeros(
        x.shape[0], self.input_dim, self.n_units, device=x.device
    )
    outputs = torch.jit.annotate(List[Tensor], [])
    for t in range(x.shape[1]):
        # eq 1: candidate cell update, computed per input variable
        j_tilda_t = torch.tanh(
            torch.einsum("bij,ijk->bik", h_tilda_t, self.W_j)
            + torch.einsum("bij,jik->bjk", x[:, t, :].unsqueeze(1), self.U_j)
            + self.b_j
        )
        # eq 5: input / forget / output gates
        i_tilda_t = torch.sigmoid(
            torch.einsum("bij,ijk->bik", h_tilda_t, self.W_i)
            + torch.einsum("bij,jik->bjk", x[:, t, :].unsqueeze(1), self.U_i)
            + self.b_i
        )
        f_tilda_t = torch.sigmoid(
            torch.einsum("bij,ijk->bik", h_tilda_t, self.W_f)
            + torch.einsum("bij,jik->bjk", x[:, t, :].unsqueeze(1), self.U_f)
            + self.b_f
        )
        o_tilda_t = torch.sigmoid(
            torch.einsum("bij,ijk->bik", h_tilda_t, self.W_o)
            + torch.einsum("bij,jik->bjk", x[:, t, :].unsqueeze(1), self.U_o)
            + self.b_o
        )
        # eq 6: cell state update
        c_tilda_t = c_tilda_t * f_tilda_t + i_tilda_t * j_tilda_t
        # eq 7: hidden state
        h_tilda_t = (o_tilda_t * torch.tanh(c_tilda_t))
        outputs += [h_tilda_t]
    outputs = torch.stack(outputs)          # (time, batch, var, units)
    outputs = outputs.permute(1, 0, 2, 3)   # (batch, time, var, units)
    # eq 8: exp-normalised temporal attention over dim 1 (time).
    # NOTE(review): exp/sum is kept rather than torch.softmax to preserve
    # the original numerics exactly.
    alphas = torch.tanh(
        torch.einsum("btij,ijk->btik", outputs, self.F_alpha_n)
        + self.F_alpha_n_b
    )
    alphas = torch.exp(alphas)
    alphas = alphas / torch.sum(alphas, dim=1, keepdim=True)
    # Context vector: attention-weighted sum over time.
    g_n = torch.sum(alphas * outputs, dim=1)
    hg = torch.cat([g_n, h_tilda_t], dim=2)
    mu = self.Phi(hg)
    # Variable attention, exp-normalised over dim 1 (variables).
    betas = torch.tanh(self.F_beta(hg))
    betas = torch.exp(betas)
    betas = betas / torch.sum(betas, dim=1, keepdim=True)
    mean = torch.sum(betas * mu, dim=1)
    return mean, alphas, betas
def forward(self, x, y_prev_t):
    """Encode x with the conv stack + RHN encoder, then decode with
    per-recurrence-depth temporal attention.

    Args:
        x: raw input sequence; passed through ``self.convs`` and
           ``self.conv_to_enc`` before encoding.
        y_prev_t: previous target values (ground truth, or the model's own
           predictions when ``self.use_predicted_output`` is set).

    Returns:
        y_t: prediction for the current step.
        y_prev_t: the (possibly updated) previous-output buffer so the
            caller can carry it into the next batch.
    """
    if self.use_predicted_output:
        # The rolling prediction buffer is sized for a full batch; trim it
        # to the actual (possibly smaller final) batch.
        y_prev_t = y_prev_t[0:x.shape[0]]
    for i in range(self.n_convs):
        x = self.convs[i](x)
    x = self.conv_to_enc(x)
    # h_t_l.shape = (batch_size, T, n_units_enc, rec_depth)
    x, h_t_l = self.RHNEncoder(x)
    # Decoder state: allocated on x's device instead of hard-coded .cuda()
    # so the module also runs on CPU.
    s = torch.zeros(x.shape[0], self.n_units_dec, device=x.device)
    for t in range(self.T):
        s_rep = s.unsqueeze(1).repeat(1, self.T, 1)
        d_t = []
        for k in range(self.rec_depth):
            h_t_k = h_t_l[..., k]
            # Additive attention over time for recurrence depth k.
            # (The original computed U_k(h_t_k) and T_k(s_rep) into unused
            # temporaries and then recomputed them here; the dead duplicate
            # work is removed -- results are identical.)
            e_t_k = self.v_k[k](
                torch.tanh(self.T_k[k](s_rep) + self.U_k[k](h_t_k))
            )
            alpha_t_k = torch.softmax(e_t_k, 1)
            # Attention-weighted context for this depth.
            d_t_k = torch.sum(h_t_k * alpha_t_k, dim=1)
            d_t.append(d_t_k)
        d_t = torch.cat(d_t, dim=1)
        if self.use_predicted_output:
            y_tilda_t, s, y_prev_t = self._last_v2(y_prev_t, d_t, s)
        else:
            y_tilda_t, s, _ = self._last_v1(y_prev_t, d_t, s, t)
    y_t = self.W(s) + self.V(d_t)
    return y_t, y_prev_t
def eval_epoch_v2(self, data_loader):
    """Evaluate one epoch, feeding previous predictions back into the model.

    Args:
        data_loader: yields (batch_x, _, batch_y) triples.

    Returns:
        true: list of per-batch ground-truth arrays (numpy).
        preds: list of per-batch prediction arrays (numpy).
        mse_val: sum of per-sample losses; divide by the dataset size for
            the mean.
    """
    mse_val = 0
    preds = []
    true = []
    # Rolling buffer of previous model outputs, one slot per sample in a
    # full batch; the model is expected to trim it when the final batch is
    # smaller (see the use_predicted_output path in the model's forward).
    batch_y_h1 = torch.zeros(self.config['batch_size'], 1)
    # no_grad: evaluation never needs the autograd graph -- the original
    # built one on every forward pass and only discarded it via detach(),
    # wasting memory. Returned values are unchanged.
    with torch.no_grad():
        for batch_x, _, batch_y in data_loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            batch_y_h1 = batch_y_h1.cuda()
            output, batch_y_h1 = self.pt_model(batch_x, batch_y_h1)
            batch_y_h1 = batch_y_h1.detach()
            output = output.squeeze(1)
            preds.append(output.detach().cpu().numpy())
            true.append(batch_y.detach().cpu().numpy())
            # Weight by batch size so partial final batches count correctly.
            mse_val += self.loss(output, batch_y).item() * batch_x.shape[0]
    return true, preds, mse_val
def train_epoch_v2(self, data_loader):
    """Run one training epoch in which the model consumes its own previous
    predictions as an auxiliary input.

    Args:
        data_loader: yields (batch_x, _, batch_y) triples.

    Returns:
        Sum of per-sample training losses over the epoch.
    """
    # Rolling buffer of previous predictions, carried across batches.
    batch_y_h = torch.zeros(self.config['batch_size'], 1)
    mse_train = 0
    for batch_x, _, batch_y in data_loader:
        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
        batch_y_h = batch_y_h.cuda()
        self.opt.zero_grad()
        y_pred, batch_y_h = self.pt_model(batch_x, batch_y_h)
        # Cut the graph so the next batch does not backprop through this one.
        batch_y_h = batch_y_h.detach()
        batch_loss = self.loss(y_pred.squeeze(1), batch_y)
        batch_loss.backward()
        # Weight by batch size so partial final batches count correctly.
        mse_train += batch_loss.item() * batch_x.shape[0]
        self.opt.step()
    return mse_train
def forward(self, x):
    """Run the recurrent highway cell over a sequence.

    Args:
        x: input tensor, assumed shape (batch, time, features) -- the time
           axis is ``x.shape[1]``. (TODO confirm shape against callers.)

    Returns:
        out: per-step hidden states, shape (batch, time, n_units).
        highway_states: per-step intermediate highway states stacked over
            time and permuted via (2, 0, 3, 1); the exact axis semantics
            depend on the layout of ``all_s`` returned by RHNCell -- verify
            against that cell's implementation.
    """
    # Initial state on x's device instead of hard-coded .cuda() so the
    # module also runs on CPU.
    s = torch.zeros(x.shape[0], self.n_units, device=x.device)
    preds = []
    highway_states = []
    for t in range(x.shape[1]):
        if self.use_batch_norm:
            # Normalise both the step input and the carried state.
            x_inp = self.bn_x(x[:, t, :])
            s = self.bn_s(s)
        else:
            x_inp = x[:, t, :]
        s, all_s = self.RHNCell(x_inp, s)
        preds.append(s)
        highway_states.append(all_s)
    preds = torch.stack(preds)
    preds = preds.permute(1, 0, 2)  # (time, batch, units) -> (batch, time, units)
    highway_states = torch.stack(highway_states)
    highway_states = highway_states.permute(2, 0, 3, 1)
    out = preds
    return out, highway_states