Example #1
0
    def forward(self, x):
        """Run the multivariate attention-LSTM over an input sequence.

        Implements the tensorized LSTM recurrence (eqs. 1, 5-7) followed by
        temporal attention (eq. 8) and a variable-mixture output.

        Args:
            x: input tensor; assumed shape (batch, time, input_dim) — the
               time axis is ``x.shape[1]``. TODO(review): confirm with callers.

        Returns:
            mean: attention-mixed prediction, summed over the variable axis.
            alphas: temporal attention weights (normalized over dim=1, time).
            betas: variable attention weights (normalized over dim=1, variables).
        """
        # Allocate recurrent state on the input's device instead of the
        # previous hard-coded .cuda(): identical on CUDA inputs, but the
        # module now also runs on CPU and non-default GPUs.
        h_tilda_t = torch.zeros(
            x.shape[0], self.input_dim, self.n_units, device=x.device)
        c_tilda_t = torch.zeros(
            x.shape[0], self.input_dim, self.n_units, device=x.device)
        outputs = torch.jit.annotate(List[Tensor], [])
        for t in range(x.shape[1]):
            # Hoisted: the original rebuilt this slice for every gate.
            x_t = x[:, t, :].unsqueeze(1)
            # eq 1: candidate cell update
            j_tilda_t = torch.tanh(
                torch.einsum("bij,ijk->bik", h_tilda_t, self.W_j) +
                torch.einsum("bij,jik->bjk", x_t, self.U_j) +
                self.b_j)
            # eq 5: input, forget and output gates
            i_tilda_t = torch.sigmoid(
                torch.einsum("bij,ijk->bik", h_tilda_t, self.W_i) +
                torch.einsum("bij,jik->bjk", x_t, self.U_i) +
                self.b_i)
            f_tilda_t = torch.sigmoid(
                torch.einsum("bij,ijk->bik", h_tilda_t, self.W_f) +
                torch.einsum("bij,jik->bjk", x_t, self.U_f) +
                self.b_f)
            o_tilda_t = torch.sigmoid(
                torch.einsum("bij,ijk->bik", h_tilda_t, self.W_o) +
                torch.einsum("bij,jik->bjk", x_t, self.U_o) +
                self.b_o)
            # eq 6: cell state
            c_tilda_t = c_tilda_t * f_tilda_t + i_tilda_t * j_tilda_t
            # eq 7: hidden state
            h_tilda_t = o_tilda_t * torch.tanh(c_tilda_t)
            outputs += [h_tilda_t]
        outputs = torch.stack(outputs)         # (time, batch, input_dim, n_units)
        outputs = outputs.permute(1, 0, 2, 3)  # (batch, time, input_dim, n_units)
        # eq 8: temporal attention — exp/normalize is softmax(tanh(...)) over time.
        alphas = torch.tanh(
            torch.einsum("btij,ijk->btik", outputs, self.F_alpha_n) +
            self.F_alpha_n_b)
        alphas = torch.exp(alphas)
        alphas = alphas / torch.sum(alphas, dim=1, keepdim=True)
        g_n = torch.sum(alphas * outputs, dim=1)
        # Context concatenated with the last hidden state, per variable.
        hg = torch.cat([g_n, h_tilda_t], dim=2)
        mu = self.Phi(hg)
        # Variable attention, normalized over the variable axis (dim=1).
        betas = torch.tanh(self.F_beta(hg))
        betas = torch.exp(betas)
        betas = betas / torch.sum(betas, dim=1, keepdim=True)
        mean = torch.sum(betas * mu, dim=1)

        return mean, alphas, betas
Example #2
0
    def forward(self, x, y_prev_t):
        if self.use_predicted_output:
            y_prev_t = y_prev_t[0:x.shape[0]]
        for conv in range(self.n_convs):
            x = self.convs[conv](x)
        x = self.conv_to_enc(x)
        x, h_t_l = self.RHNEncoder(x)  # h_T_L.shape = (batch_size, T, n_units_enc, rec_depth)
        s = torch.zeros(x.shape[0], self.n_units_dec).cuda()
        for t in range(self.T):
            s_rep = s.unsqueeze(1)
            s_rep = s_rep.repeat(1, self.T, 1)
            d_t = []
            for k in range(self.rec_depth):
                h_t_k = h_t_l[..., k]
                _ = self.U_k[k](h_t_k)
                _ = self.T_k[k](s_rep)
                e_t_k = self.v_k[k](torch.tanh(self.T_k[k](s_rep) + self.U_k[k](h_t_k)))
                alpha_t_k = torch.softmax(e_t_k, 1)
                d_t_k = torch.sum(h_t_k * alpha_t_k, dim=1)
                d_t.append(d_t_k)
            d_t = torch.cat(d_t, dim=1)
            if self.use_predicted_output:
                y_tilda_t, s, y_prev_t = self._last_v2(y_prev_t, d_t, s)
            else:
                y_tilda_t, s, _ = self._last_v1(y_prev_t, d_t, s, t)

        y_t = self.W(s) + self.V(d_t)
        return y_t, y_prev_t
Example #3
0
    def eval_epoch_v2(self, data_loader):
        mse_val = 0
        preds = []
        true = []
        batch_y_h1 = torch.zeros(self.config['batch_size'], 1)
        for batch_x, _, batch_y in data_loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            batch_y_h1 = batch_y_h1.cuda()
            output, batch_y_h1 = self.pt_model(batch_x, batch_y_h1)
            batch_y_h1 = batch_y_h1.detach()
            output = output.squeeze(1)
            preds.append(output.detach().cpu().numpy())
            true.append(batch_y.detach().cpu().numpy())
            mse_val += self.loss(output, batch_y).item() * batch_x.shape[0]

        return true, preds, mse_val
Example #4
0
    def train_epoch_v2(self, data_loader):
        # using previous predictions as input
        batch_y_h = torch.zeros(self.config['batch_size'], 1)
        mse_train = 0
        for batch_x, _, batch_y in data_loader:
            batch_x = batch_x.cuda()
            batch_y = batch_y.cuda()
            batch_y_h = batch_y_h.cuda()
            self.opt.zero_grad()
            y_pred, batch_y_h = self.pt_model(batch_x, batch_y_h)
            batch_y_h = batch_y_h.detach()
            y_pred = y_pred.squeeze(1)
            l = self.loss(y_pred, batch_y)
            l.backward()
            mse_train += l.item() * batch_x.shape[0]
            self.opt.step()

        return mse_train
Example #5
0
    def forward(self, x):
        """Unroll the recurrent highway cell over the time axis of ``x``.

        Args:
            x: input tensor; assumed (batch, time, features) — the time axis
               is ``x.shape[1]``. TODO(review): confirm with callers.

        Returns:
            out: stacked hidden states, permuted to (batch, time, n_units).
            highway_states: per-step per-depth cell states, permuted so the
                batch axis leads (exact layout depends on RHNCell's output).
        """
        # State lives on the input's device; previously a hard-coded .cuda()
        # (unchanged on CUDA inputs, but now CPU-safe).
        s = torch.zeros(x.shape[0], self.n_units, device=x.device)
        preds = []
        highway_states = []
        for t in range(x.shape[1]):
            if self.use_batch_norm:
                # Normalize both the current input slice and the carried state.
                x_inp = self.bn_x(x[:, t, :])
                s = self.bn_s(s)
            else:
                x_inp = x[:, t, :]
            s, all_s = self.RHNCell(x_inp, s)
            preds.append(s)
            highway_states.append(all_s)
        preds = torch.stack(preds).permute(1, 0, 2)  # -> (batch, time, n_units)
        highway_states = torch.stack(highway_states).permute(2, 0, 3, 1)

        # (The dead alias `out = preds` from the original is dropped.)
        return preds, highway_states