Esempio n. 1
0
    def backward(ctx, logl_grad):
        """Compute parameter gradients via a downward (root-to-leaves) pass.

        Rebuilds ``eps`` level by level from the ``beta`` / ``t_beta`` tensors
        saved by ``forward`` and turns it into gradients w.r.t. the softmax
        logits of ``A``, ``B``, ``Pi`` and ``SP``, each chained with the
        upstream gradient ``logl_grad``.

        Args:
            ctx: autograd context; ``ctx.saved_input`` is the ``tree`` dict and
                ``ctx.saved_tensors`` is ``(beta, t_beta, A, B, Pi, SP)``.
            logl_grad: upstream gradient of the per-tree log-likelihood; it is
                mapped onto nodes through ``tree['batch']``.

        Returns:
            ``(None, A_grad, B_grad, Pi_grad, SP_grad)`` -- ``None`` for the
            ``tree`` input.
        """
        # Recovering saved tensors from forward
        tree = ctx.saved_input
        beta, t_beta, A, B, Pi, SP = ctx.saved_tensors

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Gradient accumulators filled level by level (Pi_grad is computed in
        # one shot at the end, so it needs no accumulator).
        A_grad = torch.zeros_like(A)
        B_grad = torch.zeros_like(B)
        SP_grad = torch.zeros_like(SP)

        eps = torch.zeros((tree['dim'], C, n_gen), device=device)

        # Base case: eps at the roots is initialized with their beta.
        roots = tree['levels'][0][0].unique(sorted=False)
        eps[roots] = beta[roots]
        for l in tree['levels']:
            # Computing eps_{u, ch(u)}(i, j); l[0] indexes parents, l[1] children
            pos_ch = tree['pos'][l[1]]
            SP_ch = SP[pos_ch]
            A_ch = A[:, :, pos_ch].permute(2, 0, 1, 3)

            trans_ch = SP_ch.unsqueeze(1).unsqueeze(1) * A_ch
            eps_pa = eps[l[0]].unsqueeze(2)
            beta_ch = beta[l[1]].unsqueeze(1)
            t_beta_pa = t_beta[l[0]].unsqueeze(2)

            # Computing eps_{ch(u)} by summing out the parent state (dim 1)
            eps_joint = (eps_pa * trans_ch * beta_ch) / t_beta_pa
            eps_ch = eps_joint.sum(1)
            eps[l[1]] = eps_ch

            local_grad = logl_grad[tree['batch'][l[1]]]
            # Accumulating gradient in SP_grad and A_grad, grouped by position
            SP_grad = scatter((eps_ch.sum(1) - SP_ch) * local_grad,
                              index=pos_ch,
                              dim=0,
                              out=SP_grad)
            local_grad = local_grad.unsqueeze(0).unsqueeze(0)
            A_grad = scatter(
                (eps_joint - A_ch * eps_ch.unsqueeze(1)).permute(1, 2, 0, 3) *
                local_grad,
                index=pos_ch,
                dim=2,
                out=A_grad)

        eps_leaves = eps[tree['leaves']]
        local_grad = logl_grad[tree['batch'][tree['leaves']]].unsqueeze(1)
        Pi_grad = ((eps_leaves - Pi.unsqueeze(0)) * local_grad).sum(0)

        # Emission logits. Assuming B is softmax-normalized over its second
        # axis (symbols), the gradient of its logits is
        #   sum_{u: x_u = k} g_u eps_u(i) - B(i, k) * sum_u g_u eps_u(i)
        # with g_u the upstream gradient of u's tree. Both terms must keep the
        # per-node weights; counts of k times the total weight do not factorize
        # into this expression.
        eps_B = eps.permute(1, 0, 2)
        local_grad = logl_grad[tree['batch']].unsqueeze(0)
        weighted_eps = eps_B * local_grad
        B_grad = scatter(weighted_eps, index=tree['x'], dim=1, out=B_grad)
        B_grad -= B * weighted_eps.sum(1, keepdim=True)

        return None, A_grad, B_grad, Pi_grad, SP_grad
Esempio n. 2
0
    def forward(ctx, x, tree, lambda_A, lambda_B, lambda_Pi):
        """Downward prior pass followed by an upward pass of normalized betas.

        Softmax-reparameterizes the logits, propagates node priors from the
        roots to the leaves, then propagates normalized ``beta`` messages from
        the leaves back to the roots while storing the log of each node's
        normalization constant.

        Args:
            ctx: autograd context; ``(x, tree)`` and
                ``(prior, beta, t_beta, A, B, Pi)`` are stored for backward.
            x: symbol indices; ``x[tree['inv_map']]`` gives the symbol of each
                tree node, used to index the second axis of ``B``.
            tree: dict read via 'dim', 'roots', 'levels', 'inv_map', 'leaves'
                and 'trees_ind'; ``l[0]`` / ``l[1]`` of each level index the
                parent / child side of its edges.
            lambda_A: logits of A; softmax over dim 0 per unit.
            lambda_B: logits of B; softmax over dim 1 per unit.
            lambda_Pi: logits of Pi; softmax over dim 0 per unit.

        Returns:
            Per-node log normalizers summed with ``scatter`` over
            ``tree['trees_ind']``.
        """
        # Softmax Reparameterization
        sm_A, sm_B, sm_Pi = [], [], []
        for i in range(lambda_A.size(-1)):
            sm_A.append(F.softmax(lambda_A[:, :, i], dim=0))
            sm_B.append(F.softmax(lambda_B[:, :, i], dim=1))
            sm_Pi.append(F.softmax(lambda_Pi[:, i], dim=0))

        A, B, Pi = torch.stack(sm_A, dim=-1), torch.stack(sm_B, dim=-1), torch.stack(sm_Pi, dim=-1)

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Preliminary Downward recursion: init
        prior = torch.zeros((tree['dim'], C, n_gen), device=device)

        # Preliminary Downward recursion: base case
        prior[tree['roots']] = Pi

        # Preliminary Downward recursion: prior_ch(i) = sum_j A(i, j) prior_pa(j)
        for l in tree['levels']:
            prior_pa = prior[l[0]].unsqueeze(1)
            prior_l = (A.unsqueeze(0) * prior_pa).sum(2)
            prior[l[1]] = prior_l

        # Upward recursion: init (every node starts as prior * emission)
        beta = prior * B[:, x[tree['inv_map']]].permute(1, 0, 2)
        t_beta = torch.zeros((tree['dim'], C, n_gen), device=device)
        log_likelihood = torch.zeros((tree['dim'], n_gen), device=device)

        # Upward Recursion: base case
        beta_leaves_unnorm = beta[tree['leaves']]
        nu = beta_leaves_unnorm.sum(dim=1)

        beta[tree['leaves']] = beta_leaves_unnorm / nu.unsqueeze(1)
        log_likelihood[tree['leaves']] = nu.log()

        # Upward Recursion
        for l in reversed(tree['levels']):
            # Computing beta_uv children: sum_i A(i, j) beta_ch(i) / prior_ch(i)
            beta_ch = beta[l[1]].unsqueeze(2)
            prior_l = prior[l[1]].unsqueeze(2)
            beta_uv = (A.unsqueeze(0) * beta_ch / prior_l).sum(1)
            t_beta[l[1]] = beta_uv

            # Computing beta on level. scatter with reduce='mul' and out=beta
            # multiplies each parent's existing prior * emission value by the
            # product of its children's messages, so beta[pa_idx] is already
            # the full unnormalized beta; multiplying by the pre-scatter value
            # again would count prior * emission twice.
            pa_idx = l[0].unique(sorted=False)
            beta = scatter(src=beta_uv, index=l[0], dim=0, out=beta, reduce='mul')
            beta_l_unnorm = beta[pa_idx]
            nu = beta_l_unnorm.sum(1)

            beta[pa_idx] = beta_l_unnorm / nu.unsqueeze(1)
            log_likelihood[pa_idx] = nu.log()

        ctx.saved_input = x, tree
        ctx.save_for_backward(prior, beta, t_beta, A, B, Pi)

        return scatter(log_likelihood, tree['trees_ind'], dim=0)
Esempio n. 3
0
    def forward(ctx, tree, lambda_A, lambda_B, lambda_Pi, lambda_SP):
        """Upward (leaves-to-roots) recursion with position-dependent transitions.

        Softmax-reparameterizes the logits, then propagates normalized
        ``beta`` messages from the leaves to the roots, storing the log of
        each node's normalization constant.

        Args:
            ctx: autograd context; ``tree`` and
                ``(beta, t_beta, A, B, Pi, SP)`` are stored for backward.
            tree: dict read via 'dim', 'x', 'leaves', 'levels', 'pos' and
                'batch'. ``l[0]`` / ``l[1]`` of each level index the parent /
                child side of its edges; ``tree['pos']`` gives each child's
                position, used to select A and SP.
            lambda_A: logits of A; for each unit i and position j, softmax
                over dim 0 of ``lambda_A[:, :, j, i]``.
            lambda_B: logits of B; softmax over dim 1 per unit.
            lambda_Pi: logits of Pi; softmax over dim 0 per unit.
            lambda_SP: logits of SP; softmax over dim 0 per unit.

        Returns:
            Per-node log normalizers summed with ``scatter`` over
            ``tree['batch']``.
        """
        # Softmax reparameterization
        sm_A, sm_B, sm_Pi, sm_SP = [], [], [], []
        for i in range(lambda_A.size(-1)):
            sm_A.append(
                torch.cat([
                    F.softmax(lambda_A[:, :, j, i], dim=0).unsqueeze(2)
                    for j in range(lambda_A.size(2))
                ],
                          dim=2))
            sm_B.append(F.softmax(lambda_B[:, :, i], dim=1))
            sm_Pi.append(F.softmax(lambda_Pi[:, i], dim=0))
            sm_SP.append(F.softmax(lambda_SP[:, i], dim=0))

        A, B, Pi, SP = torch.stack(sm_A, dim=-1), torch.stack(
            sm_B, dim=-1), torch.stack(sm_Pi, dim=-1), torch.stack(sm_SP,
                                                                   dim=-1)

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Upward recursion: init
        beta = torch.zeros((tree['dim'], C, n_gen), device=device)
        t_beta = torch.zeros((tree['dim'], C, n_gen), device=device)
        log_likelihood = torch.zeros((tree['dim'], n_gen), device=device)

        # Upward recursion: base case (leaves use Pi as their prior)
        B_leaves = B[:, tree['x'][tree['leaves']]]
        beta_leaves = (Pi.unsqueeze(1) * B_leaves).permute(1, 0, 2)
        nu = beta_leaves.sum(dim=1)

        beta[tree['leaves']] = beta_leaves / nu.unsqueeze(1)
        log_likelihood[tree['leaves']] = nu.log()

        # Upward recursion
        for l in reversed(tree['levels']):
            # Per-edge message t_beta_ch(i) = SP(pos) * sum_j A(i, j, pos) beta_ch(j),
            # then summed over each parent's children into t_beta (no prior
            # division happens here)
            pos_ch = tree['pos'][l[1]]
            SP_ch = SP[pos_ch].unsqueeze(1).unsqueeze(2)
            A_ch = A[:, :, pos_ch].permute(2, 0, 1, 3)
            beta_ch = beta[l[1]].unsqueeze(1)

            t_beta_ch = (SP_ch * A_ch * beta_ch).sum(2)
            # scatter (default sum) accumulates into t_beta at the parent rows
            t_beta = scatter(src=t_beta_ch, index=l[0], dim=0, out=t_beta)

            # Parents of this level: multiply by their emission and normalize
            u_idx = l[0].unique(sorted=False)
            B_l = B[:, tree['x'][u_idx]].permute(1, 0, 2)
            beta_l = t_beta[u_idx] * B_l
            nu = beta_l.sum(dim=1)

            beta[u_idx] = beta_l / nu.unsqueeze(1)
            log_likelihood[u_idx] = nu.log()

        ctx.saved_input = tree
        ctx.save_for_backward(beta, t_beta, A, B, Pi, SP)

        return scatter(log_likelihood, tree['batch'], dim=0)
Esempio n. 4
0
    def backward(ctx, logl_grad):
        """Compute parameter gradients via a downward (root-to-leaves) pass.

        Builds ``eps`` level by level from the ``prior``, ``beta`` and
        ``t_beta`` tensors saved by ``forward`` and converts it into gradients
        w.r.t. the softmax logits of ``A``, ``B`` and ``Pi``, each chained with
        the upstream gradient ``logl_grad``.

        Args:
            ctx: autograd context; ``ctx.saved_input`` is the ``tree`` dict and
                ``ctx.saved_tensors`` is ``(prior, beta, t_beta, A, B, Pi)``.
            logl_grad: upstream gradient of the per-tree log-likelihood; mapped
                onto nodes through ``tree['batch']``.

        Returns:
            ``(None, A_grad, B_grad, Pi_grad)`` -- ``None`` for the ``tree``
            input.
        """
        # Recovering saved tensors from forward
        tree = ctx.saved_input
        prior, beta, t_beta, A, B, Pi = ctx.saved_tensors

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Creating parameter gradient tensors
        A_grad, B_grad = torch.zeros_like(A), torch.zeros_like(B)

        eps = torch.zeros((tree['dim'], C, n_gen), device=device)

        # Base case: eps at the roots is initialized with their beta.
        roots = tree['levels'][0][0].unique(sorted=False)
        eps_roots = beta[roots]
        eps[roots] = eps_roots
        for l in tree['levels']:
            # Computing eps_{u, pa(u)}(i, j): child state i on dim 1,
            # parent state j on dim 2
            eps_pa = eps[l[0]].unsqueeze(1)
            pos_ch = tree['pos'][l[1]]
            A_ch = A[:, :, pos_ch].permute(2, 0, 1, 3)
            beta_ch = beta[l[1]].unsqueeze(2)
            eps_trans_pa = A_ch * eps_pa

            t_beta_ch = t_beta[l[1]].unsqueeze(1)
            prior_ch = prior[l[1]].unsqueeze(2)

            eps_joint = (beta_ch * eps_trans_pa) / (prior_ch * t_beta_ch)

            # Computing eps_u(i) by summing out the parent state
            eps_ch = eps_joint.sum(2)
            eps[l[1]] = eps_ch

            local_grad = logl_grad[tree['batch'][l[1]]].unsqueeze(0).unsqueeze(
                0)
            # Accumulating gradient in A_grad, grouped by child position
            A_grad = scatter(
                (eps_joint - eps_trans_pa).permute(1, 2, 0, 3) * local_grad,
                index=pos_ch,
                dim=2,
                out=A_grad)

        # Emission logits. Assuming B is softmax-normalized over its second
        # axis (symbols), the gradient of its logits is
        #   sum_{u: x_u = k} g_u eps_u(i) - B(i, k) * sum_u g_u eps_u(i)
        # with g_u the upstream gradient of u's tree.
        eps_B = eps.permute(1, 0, 2)
        local_grad = logl_grad[tree['batch']].unsqueeze(0)
        weighted_eps = eps_B * local_grad
        B_grad = scatter(weighted_eps, index=tree['x'], dim=1, out=B_grad)
        B_grad -= B * weighted_eps.sum(1, keepdim=True)

        local_grad = logl_grad[tree['batch'][roots]].unsqueeze(1)
        Pi_grad = ((eps_roots - Pi.unsqueeze(0)) * local_grad).sum(0)

        return None, A_grad, B_grad, Pi_grad
Esempio n. 5
0
    def forward(ctx, x, tree, lambda_A, lambda_B, lambda_Pi):
        """Upward (leaves-to-roots) recursion with mean-aggregated children.

        Softmax-reparameterizes the logits, then propagates normalized
        ``beta`` messages from the leaves to the roots, storing the log of
        each node's normalization constant.

        Args:
            ctx: autograd context; ``(x, tree)`` and
                ``(beta, t_beta, A, B, Pi)`` are stored for backward.
            x: symbol indices; ``x[tree['inv_map'][...]]`` gives the symbol of
                each tree node, used to index the second axis of ``B``.
            tree: dict read via 'dim', 'inv_map', 'leaves', 'levels' and
                'trees_ind'; ``l[0]`` / ``l[1]`` of each level index the
                parent / child side of its edges.
            lambda_A: logits of A; softmax over dim 0 per unit.
            lambda_B: logits of B; softmax over dim 1 per unit.
            lambda_Pi: logits of Pi; softmax over dim 0 per unit.

        Returns:
            Per-node log normalizers summed with ``scatter`` over
            ``tree['trees_ind']``.
        """
        # Softmax Reparameterization
        sm_A, sm_B, sm_Pi = [], [], []
        for i in range(lambda_A.size(-1)):
            sm_A.append(F.softmax(lambda_A[:, :, i], dim=0))
            sm_B.append(F.softmax(lambda_B[:, :, i], dim=1))
            sm_Pi.append(F.softmax(lambda_Pi[:, i], dim=0))

        A, B, Pi = torch.stack(sm_A, dim=-1), torch.stack(
            sm_B, dim=-1), torch.stack(sm_Pi, dim=-1)

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Upward recursion: init
        beta = torch.zeros((tree['dim'], C, n_gen), device=device)
        t_beta = torch.zeros((tree['dim'], C, n_gen), device=device)
        log_likelihood = torch.zeros((tree['dim'], n_gen), device=device)

        # Upward Recursion: base case (leaves use Pi as their prior)
        Pi_leaves = Pi.unsqueeze(0)
        leaves_idx = tree['inv_map'][tree['leaves']]
        B_leaves = B[:, x[leaves_idx]].permute(1, 0, 2)
        beta_leaves = Pi_leaves * B_leaves
        nu = beta_leaves.sum(dim=1)

        beta[tree['leaves']] = beta_leaves / nu.unsqueeze(1)
        log_likelihood[tree['leaves']] = nu.log()

        # Upward Recursion
        for l in reversed(tree['levels']):
            # Per-edge message t_beta_ch(i) = sum_j A(i, j) * beta_ch(j),
            # then averaged over each parent's children (reduce='mean')
            beta_ch = beta[l[1]]
            t_beta_ch = (A.unsqueeze(0) * beta_ch.unsqueeze(1)).sum(2)
            t_beta = scatter(src=t_beta_ch,
                             index=l[0],
                             dim=0,
                             out=t_beta,
                             reduce='mean')

            # Parents of this level: multiply by their emission and normalize
            u_idx = l[0].unique(sorted=False)
            B_l = B[:, x[tree['inv_map'][u_idx]]].permute(1, 0, 2)
            beta_l = B_l * t_beta[u_idx]
            nu = beta_l.sum(dim=1)

            beta[u_idx] = beta_l / nu.unsqueeze(1)
            log_likelihood[u_idx] = nu.log()

        ctx.saved_input = x, tree
        ctx.save_for_backward(beta, t_beta, A, B, Pi)

        return scatter(log_likelihood, tree['trees_ind'], dim=0)
Esempio n. 6
0
    def backward(ctx, likelihood, posterior_i):
        """Return EM-style gradients for the logits of Q and B.

        Args:
            ctx: autograd context; ``ctx.saved_input`` is ``(x, edge_index)``
                and ``ctx.saved_tensors`` is
                ``(posterior_il, posterior_i, Q, B)``.
            likelihood: upstream gradient for the first forward output.
            posterior_i: upstream gradient for the second forward output.
                NOTE(review): neither upstream gradient is used, so the
                returned values are not chained with the loss -- confirm this
                is intended.

        Returns:
            ``(None, None, None, Q_grad, B_grad)`` -- ``None`` for ``x``,
            ``prev_h`` and ``edge_index`` (assuming the forward's input order
            is ``x, prev_h, edge_index, lambda_Q, lambda_B``).
        """
        x, _ = ctx.saved_input
        # Distinct local name so the saved posterior does not shadow the
        # upstream-gradient parameter.
        posterior_il, post_i, Q, B = ctx.saved_tensors

        # Assuming posterior_il is (nodes, i, l, n_gen) and Q(i, l) is
        # normalized over i (softmax over dim 0): the gradient of its logits is
        #   sum_u [P(i, l | u) - Q(i, l) * P(l | u)],
        # where P(l | u) is the joint posterior summed over the node state i.
        post_l = posterior_il.sum(1, keepdim=True)
        Q_grad = (posterior_il - Q.unsqueeze(0) * post_l).sum(0)

        # Assuming B(i, k) is normalized over k (softmax over dim 1):
        #   sum_{u: x_u = k} P(i | u) - B(i, k) * sum_u P(i | u).
        post_nodes = post_i.permute(1, 0, 2)
        B_grad = scatter(post_nodes, index=x, dim=1, out=torch.zeros_like(B))
        B_grad -= B * post_nodes.sum(1, keepdim=True)

        return None, None, None, Q_grad, B_grad
Esempio n. 7
0
    def backward(ctx, log_likelihood):
        """Compute EM-style gradients for the logits of A, B and Pi.

        Runs a downward (root-to-leaves) pass that rebuilds ``eps`` from the
        ``beta`` / ``t_beta`` tensors saved by ``forward``, dividing each
        parent's contribution by its number of children in the level.

        Args:
            ctx: autograd context; ``ctx.saved_input`` is ``(x, tree)`` and
                ``ctx.saved_tensors`` is ``(beta, t_beta, A, B, Pi)``.
            log_likelihood: upstream gradient from autograd.
                NOTE(review): it is not used, so the returned values are not
                chained with the loss -- confirm this is intended.

        Returns:
            ``(None, None, A_grad, B_grad, Pi_grad)`` -- ``None`` for ``x`` and
            ``tree``.
        """
        # Recovering saved tensors from forward
        x, tree = ctx.saved_input
        beta, t_beta, A, B, Pi = ctx.saved_tensors

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Creating gradient accumulators (Pi_grad is computed in one shot)
        A_grad, B_grad = torch.zeros_like(A), torch.zeros_like(B)

        # Downward recursion: init
        eps = torch.zeros((tree['dim'], C, n_gen), device=device)
        out_deg = torch.zeros(tree['dim'], device=device)

        # Downward recursion: base case
        eps[tree['roots']] = beta[tree['roots']]

        # Downward recursion
        for l in tree['levels']:
            # Counting each parent's children in this level
            out_deg = scatter(torch.ones_like(l[1],
                                              dtype=out_deg.dtype,
                                              device=device),
                              dim=0,
                              index=l[0],
                              out=out_deg)

            # Computing eps_{u, ch_i(u)}(i, j): parent state i on dim 1,
            # child state j on dim 2
            t_beta_pa = t_beta[l[0]].unsqueeze(2)
            eps_pa = eps[l[0]].unsqueeze(2)
            beta_ch = beta[l[1]].unsqueeze(1)

            eps_joint = (eps_pa * A.unsqueeze(0) * beta_ch) / (
                t_beta_pa * out_deg[l[0]].view(-1, 1, 1, 1))

            # Summing out the parent state
            eps_ch = eps_joint.sum(1)
            eps[l[1]] = eps_ch
            A_grad += (eps_joint - A.unsqueeze(0) * eps_ch.unsqueeze(1)).sum(0)

        eps_leaves = eps[tree['leaves']]
        Pi_grad = eps_leaves.sum(0) - tree['leaves'].size(0) * Pi

        # Assuming B(i, k) is normalized over k (softmax over dim 1):
        #   sum_{u: x_u = k} eps_u(i) - B(i, k) * sum_u eps_u(i).
        # The subtracted term runs over *all* nodes, not only those emitting k.
        eps_nodes = eps.permute(1, 0, 2)
        x_trees = x[tree['inv_map']]
        B_grad = scatter(eps_nodes, index=x_trees, dim=1, out=B_grad)
        B_grad -= B * eps_nodes.sum(1, keepdim=True)

        return None, None, A_grad, B_grad, Pi_grad
Esempio n. 8
0
    def forward(self, x, prev_h, edge_index, pos):
        """Edge-wise E-step, plus a gradient step on the expected log-likelihood.

        Combines the emissions of the nodes in ``edge_index[0]``, the
        transitions ``Q_neigh[:, :, pos]`` and the previous-layer states of the
        nodes in ``edge_index[1]`` into per-edge joint posteriors. When
        ``self.training`` is true and ``self.frozen`` is false, it also calls
        ``.backward()`` on the negated expected complete-data log-likelihood,
        so gradients are accumulated as a side effect of this call.

        Args:
            x: per-node symbol indices, used to index the second axis of B.
            prev_h: previous-layer state distributions, gathered by
                ``edge_index[1]``.
            edge_index: two-row edge index; rows 0 / 1 are the node being
                scored / its neighbour.
            pos: indices into the third axis of Q_neigh -- presumably one
                entry per edge; confirm against the caller.

        Returns:
            ``(likelihood, posterior_i)``: the log of each node's mean edge
            likelihood (passed through ``.squeeze()``, which drops every
            size-1 dim) and the detached per-node posterior.
        """
        # NOTE(review): _softmax_reparameterization is not visible here;
        # presumably it returns softmax-normalized Q and B -- confirm.
        Q_neigh, B = self._softmax_reparameterization()
        
        prev_h_neigh = prev_h[edge_index[1]].unsqueeze(1)
        trans_neigh = Q_neigh[:, :, pos].permute(2, 0, 1, 3)
        
        B_nodes = B[:, x[edge_index[0]]].permute(1, 0, 2).unsqueeze(2)   # edges x C x 1 x n_gen
        unnorm_posterior = B_nodes * trans_neigh * prev_h_neigh # edges x C x C x n_gen
        # Per-node mean (over its edges) of the per-edge total mass
        likelihood = scatter(unnorm_posterior.sum([1, 2], keepdim=True), edge_index[0], dim=0, reduce='mean')
        
        posterior_il = (unnorm_posterior / (likelihood[edge_index[0]] + 1e-16)).detach() # edges x C x C x n_gen
        # NOTE(review): since each edge is divided by the node's *mean* edge
        # mass, posterior_i sums (over the state axis) to roughly the number
        # of edges with that source node, not to 1 -- confirm this is intended.
        posterior_i = scatter(posterior_il.sum(2), index=edge_index[0], dim=0).detach() # nodes x C x n_gen
        if self.training and not self.frozen:
            B_nodes = B[:, x].permute(1, 0, 2)  # nodes x C x n_gen, necessary for backpropagating in the new, detached graph
            exp_likelihood = (posterior_il * trans_neigh.log()).sum() + (posterior_i * B_nodes.log()).sum()
            (-exp_likelihood).backward()

        likelihood = likelihood.log().squeeze()
        return likelihood, posterior_i
Esempio n. 9
0
    def forward(self, data):
        """Map a batch of graphs to outputs via per-graph log-likelihood features.

        Node-level log-likelihoods from ``self.cgmm`` are summed per graph
        (grouped by ``data.batch``), projected through ``self.contrastive``
        and squashed with ``tanh``. The result is fed to ``self.gru`` with
        ``self.h0`` repeated once per graph, and the GRU output (last two dims
        flattened) goes through ``self.output``.
        """
        node_feats, edges, graph_of_node = data.x, data.edge_index, data.batch

        per_graph_logl = scatter(self.cgmm(node_feats, edges), graph_of_node, dim=0)
        contrastive_feats = (per_graph_logl @ self.contrastive).tanh()

        n_graphs = graph_of_node.max() + 1
        gru_out = self.gru(contrastive_feats, self.h0.repeat(1, n_graphs, 1))[0]

        return self.output(gru_out.flatten(start_dim=-2))
Esempio n. 10
0
    def backward(ctx, logl_grad, posterior):
        """Return EM-style gradients for the logits of B and Pi.

        Args:
            ctx: autograd context; ``ctx.saved_input`` is ``x`` (per-node
                symbol indices) and ``ctx.saved_tensors`` is
                ``(posterior, B, Pi)``.
            logl_grad: upstream gradient for the first forward output.
            posterior: upstream gradient for the second forward output.
                NOTE(review): neither upstream gradient is used, so the
                returned values are not chained with the loss -- confirm this
                is intended.

        Returns:
            ``(None, B_grad, Pi_grad)`` -- ``None`` for ``x``.
        """
        x = ctx.saved_input
        # Distinct local name so the saved posterior does not shadow the
        # upstream-gradient parameter.
        post, B, Pi = ctx.saved_tensors

        # Assuming B(i, k) is normalized over k (softmax over dim 1):
        #   sum_{u: x_u = k} P(i | u) - B(i, k) * sum_u P(i | u).
        # The subtracted term runs over *all* nodes, not only those emitting k.
        post_nodes = post.permute(1, 0, 2)
        B_grad = scatter(post_nodes, index=x, dim=1, out=torch.zeros_like(B))
        B_grad -= B * post_nodes.sum(1, keepdim=True)

        # Every node contributes P(i | u) - Pi(i).
        Pi_grad = post.sum(0) - post.size(0) * Pi

        return None, B_grad, Pi_grad
Esempio n. 11
0
    def backward(ctx, log_likelihood):
        """Compute EM-style gradients for the logits of A, B and Pi.

        Runs a downward (root-to-leaves) pass that builds ``eps`` from the
        ``prior``, ``beta`` and ``t_beta`` tensors saved by ``forward``.

        Args:
            ctx: autograd context; ``ctx.saved_input`` is ``(x, tree)`` and
                ``ctx.saved_tensors`` is ``(prior, beta, t_beta, A, B, Pi)``.
            log_likelihood: upstream gradient from autograd.
                NOTE(review): it is not used, so the returned values are not
                chained with the loss -- confirm this is intended.

        Returns:
            ``(None, None, A_grad, B_grad, Pi_grad)`` -- ``None`` for ``x`` and
            ``tree``.
        """
        # Recovering saved tensors from forward
        x, tree = ctx.saved_input
        prior, beta, t_beta, A, B, Pi = ctx.saved_tensors

        # Getting model info
        C, n_gen, device = A.size(0), A.size(-1), A.device

        # Creating parameter gradient tensors
        A_grad, B_grad = torch.zeros_like(A), torch.zeros_like(B)

        # Downward recursion: init
        eps = torch.zeros((tree['dim'], C, n_gen), device=device)

        # Downward recursion: base case
        eps[tree['roots']] = beta[tree['roots']]

        for l in tree['levels']:
            # Computing eps_{u, pa(u)}(i, j): child state i on dim 1,
            # parent state j on dim 2
            eps_pa = eps[l[0]].unsqueeze(1)
            beta_ch = beta[l[1]].unsqueeze(2)
            eps_trans_pa = A.unsqueeze(0) * eps_pa

            t_beta_ch = t_beta[l[1]].unsqueeze(1)
            prior_ch = prior[l[1]].unsqueeze(2)

            eps_joint = (beta_ch * eps_trans_pa) / (prior_ch * t_beta_ch)

            # Computing eps_u(i) by summing out the parent state
            eps_ch = eps_joint.sum(2)
            eps[l[1]] = eps_ch

            A_grad += (eps_joint - eps_trans_pa).sum(0)

        eps_roots = eps[tree['roots']]
        Pi_grad = eps_roots.sum(0) - tree['roots'].size(0) * Pi

        # Assuming B(i, k) is normalized over k (softmax over dim 1):
        #   sum_{u: x_u = k} eps_u(i) - B(i, k) * sum_u eps_u(i).
        # The subtracted term runs over *all* nodes, not only those emitting k.
        eps_nodes = eps.permute(1, 0, 2)
        x_trees = x[tree['inv_map']]
        B_grad = scatter(eps_nodes, index=x_trees, dim=1, out=B_grad)
        B_grad -= B * eps_nodes.sum(1, keepdim=True)

        return None, None, A_grad, B_grad, Pi_grad
Esempio n. 12
0
    def forward(ctx, x, prev_h, edge_index, lambda_Q, lambda_B):
        """E-step of a contextual layer: node posteriors given neighbour states.

        Args:
            ctx: autograd context; ``(x, edge_index)`` and
                ``(posterior_il, posterior_i, Q, B)`` are stored for backward.
            x: per-node symbol indices, used to index the second axis of ``B``;
                its length sets the number of output rows.
            prev_h: previous-layer state distributions, gathered along
                ``edge_index[1]`` and mean-aggregated onto ``edge_index[0]``.
            edge_index: two-row edge index.
            lambda_Q: logits of Q; softmax over dim 0 per unit.
            lambda_B: logits of B; softmax over dim 1 per unit.

        Returns:
            ``(log_mass, posterior_i)``: per-node log of the unnormalized
            posterior mass (summed over both state axes) and the per-node
            posterior over the node state.
        """
        Q, B = [], []

        for j in range(lambda_Q.size(-1)):
            Q.append(F.softmax(lambda_Q[:, :, j], dim=0))
            B.append(F.softmax(lambda_B[:, :, j], dim=1))

        Q, B = torch.stack(Q, dim=-1), torch.stack(B, dim=-1)

        # Mean previous-layer state over each node's neighbours. dim_size keeps
        # one row per node even when the highest-indexed nodes receive no
        # edges; otherwise scatter sizes the output as max(edge_index[0]) + 1
        # and the product with B_nodes below fails on mismatched shapes.
        prev_h_neigh = prev_h[edge_index[1]]
        prev_h_neigh_aggr = scatter(prev_h_neigh, edge_index[0], dim=0,
                                    dim_size=x.size(0), reduce='mean')

        B_nodes = B[:, x].permute(1, 0, 2).unsqueeze(2)   # nodes x C x 1 x n_gen
        prev_h_neigh_aggr = prev_h_neigh_aggr.unsqueeze(1) # nodes x 1 x C x n_gen
        # Small constant keeps the normalization and log finite
        unnorm_posterior = B_nodes * (Q * prev_h_neigh_aggr) + 1e-12

        posterior_il = (unnorm_posterior / unnorm_posterior.sum([1, 2], keepdim=True)) # nodes x C x C x n_gen
        posterior_i = posterior_il.sum(2)

        ctx.saved_input = x, edge_index
        ctx.save_for_backward(posterior_il, posterior_i, Q, B)

        return unnorm_posterior.sum([1, 2]).log(), posterior_i